################################################################
#####          Main code for model computations    		   #####
#####          Date created: 06.09.2024                    #####
#####          Date last changed: 16.09.2025               #####
################################################################

#region ### 1. Load packages and directories ###

# Packages Pkg.add("DataFrames") Pkg.update("CSV")
using Pkg
using Plots, QuantEcon, LinearAlgebra, DataFrames, Interpolations, JLD, Optim, Kronecker, Roots, GLM, CSV
using StatsBase, NaNMath; nm=NaNMath
using SpecialFunctions
using GSL
using Parameters
using ProgressMeter
using Measures
using RData
using Random
using Loess
using Distributions
using StatsPlots, CairoMakie
using XLSX
using NLsolve
using Gaston
using LaTeXStrings

# Directories pwd()
# Make all relative paths below resolve against the folder that contains this script.
# `@__DIR__` is used instead of `dirname(@__FILE__)`: when the code is evaluated from
# the REPL, `@__FILE__` is a pseudo-name such as "REPL[1]", its `dirname` is the empty
# string and `cd("")` throws. `@__DIR__` falls back to the current working directory
# in that case and is identical to `dirname(@__FILE__)` when the file is run/included.
cd(@__DIR__)

# Relative data/output folders. The trailing separator is kept exactly as before;
# NOTE(review): the included files may build paths by string concatenation -- confirm
# before simplifying these to plain "data"/"output".
if Sys.iswindows()
	ddat = "data\\"
	dout = "output\\"
else
	ddat = "data/"
	dout = "output/"
end

#endregion 

#region ### 2. Load data & functions ###

# define constants needed throughout 
# (global `const`s; NOTE(review): presumably also read by the function files included
# below -- confirm before renaming any of them)
const beta = 0.95 # enters as 1/beta in rental_rate and as 1/(1-beta) in the VF guesses below
const delta = 0.1 
const rental_rate = 1/beta + delta # later used as (rental_rate - delta - 1) = 1/beta - 1 for the return on household assets
const vat_tax = 0.1 # VAT tax (official rate)
const profit_tax = 0.2 # Corporate income tax (approximately official rate; abstract from variation in marginal rates)

# define further parameters
# Read the elasticity table from the data folder; only its first row is used below.
gross_elasticities = DataFrame(CSV.File(joinpath(ddat,"gross_elasticities.csv"); header=1))

# Capital and labor shares are scaled by 1/(1 - vat_tax); the materials share is used as is.
const alpha_gross = gross_elasticities.med_cap_share_1digit_gross[1]/(1-vat_tax)
const beta_gross = gross_elasticities.med_lab_share_1digit_gross[1]/(1-vat_tax)
const gamma_gross = gross_elasticities.med_mat_share_1digit_gross[1]
# Sum of the three (scaled) shares and the implied ratio eta_tilde/(1 - eta_tilde)
const eta_tilde = alpha_gross + beta_gross + gamma_gross
const elasticity_ratio = eta_tilde/(1-eta_tilde)

# Load Data
# NOTE(review): objects used later but not defined in this file (z_star_grid,
# z_star_uncond_distrib, expected_profits_baseline, profits_NC_baseline,
# baseline_SS_distrib) presumably come from get_data.jl -- confirm.
include("get_data.jl")

# Load Functions
# Must be included after the constants above have been defined.
include("function_main_helper.jl")
include("function_get_VF.jl")
include("function_get_distribution.jl")

#endregion 

#region ### 3. Baseline GE computation ###

# Baseline equilibrium in three steps: (1) expected profits on the state grid,
# (2) value function and exit probabilities, (3) stationary distribution.
# Bare expressions (no assignment) are interactive sanity checks: they only display a
# value when evaluated in the REPL and have no effect when the file is run as a script.

# define baseline parameters 
baseline_param = define_param()

# Step 1: Get expected profits by state 
results_profits_baseline = compute_profits_grid_baseline(z_star_grid = z_star_grid, n_eps = 500, param = baseline_param)
# Sum of absolute deviations from the reference vector `expected_profits_baseline`
# (not defined in this file)
sum(abs.(results_profits_baseline.expected_profits .- expected_profits_baseline)) # results are very similar, but not the same as in R 

# Step 2: find baseline VF
# Initial guess: profits_NC_baseline scaled by 1/(1-beta)
baseline_VF = get_VF_baseline(
	param = baseline_param, 
	guess_exp_VF = (1/(1-beta))*profits_NC_baseline, 
	exp_profits = results_profits_baseline.expected_profits, 
	verbose = true)

# interquartile range of exit probabilities: basically same as in R (as sanity check!) 
# NOTE(review): 630 and 787 are hard-coded grid positions; they are only meaningful for
# the current z_star_grid -- confirm if the grid changes.
baseline_VF.exit_proba[630] # should be around 0.097
baseline_VF.exit_proba[787] # should be around 0.062

# can also back out implied entry cost (which is in line with free entry condition and w = 1) 
# Value function weighted by the unconditional z* distribution
const entry_cost = sum(baseline_VF.VF .* z_star_uncond_distrib) # Unconditional value of entry (This is the one used in paper!) 

# Step 3: Find baseline stationary distribution
# The unconditional z* distribution serves both as starting point and as entry distribution.
baseline_stationary_distrib = get_stationary_distrib(
	initial_distrib = z_star_uncond_distrib, 
	entry_distrib = z_star_uncond_distrib,
	exit_proba = baseline_VF.exit_proba, verbose = true, crit = 1e-12)

# Check that baseline stationary distribution same as in R: Basically true! 
sum(abs.(baseline_stationary_distrib.SS_distrib .- baseline_SS_distrib))

## Can now use SS_distrib to look at distribution of subsidies 

# Number of epsilon grid points per productivity state. Must equal the `n_eps` that was
# passed to compute_profits_grid_baseline when building results_profits_baseline (500).
n_eps_subsidy_sample = 500

# Weight of every (z, epsilon) cell: probability of epsilon * stationary mass of z.
# The stationary distribution is repeated `n_eps_subsidy_sample` times per z state so
# that it lines up with the stacked epsilon grid.
subsidy_sample_weights = results_profits_baseline.epsilon_proba_variation .* vec(repeat(baseline_stationary_distrib.SS_distrib, inner=n_eps_subsidy_sample))

# Dedicated, seeded RNG: the sample mean below is used as the fixed subsidy rate in the
# later counterfactuals, so the draw has to be reproducible across runs. A local RNG is
# used so the global random stream is left untouched.
subsidy_sample_rng = MersenneTwister(20240906)

# look at distribution of subsidies 
# Draw 10k cell indices (with replacement) proportional to their weight. The index
# range is derived from the weights instead of being hard-coded to 500000.
sample_subsidies_index = sample(
	subsidy_sample_rng,
	1:length(subsidy_sample_weights), # indices of all (z, epsilon) cells
	Weights(subsidy_sample_weights),
	10000) # get 10k elements 
sample_subsidies = results_profits_baseline.subsidy_variation[sample_subsidies_index]
sample_epsilon = results_profits_baseline.epsilon_variation[sample_subsidies_index]

# The numbers quoted in the comments below stem from the authors' earlier (unseeded)
# runs; expect small sampling differences.
cor(sample_subsidies, sample_epsilon) # Strongly negatively correlated: Higher epsilon, lower subsidies!! 
cor(sample_subsidies, vec(repeat(z_star_grid, inner=n_eps_subsidy_sample))[sample_subsidies_index]) # Correlation of z & subsidies is positive 
describe(DataFrame(subsidies = sample_subsidies), :all, cols=:subsidies) # avg around 0.32, range from 0.07 to 0.55
CairoMakie.density(sample_subsidies) # density is bimodal: many around 0.15 & around 0.4-0.5 
# how much do firms spend on rent-seeking on average? 6.5% of total intermediate inputs (5 - 8.5% IQR)
sample_rent_seeking_share = results_profits_baseline.rent_seeking_share_variation[sample_subsidies_index]
describe(DataFrame(rent_seeking_share = sample_rent_seeking_share), :all, cols=:rent_seeking_share)

### Policy Plot (Figure 4 in paper) ### 

### Add policy plots for two different values for epsilon (25th and 75th percentile) 
epsilon_policies = [quantile(results_profits_baseline.epsilon_variation, 0.25), quantile(results_profits_baseline.epsilon_variation, 0.75)]
results_policies_baseline = compute_policies_baseline(z_star_grid = z_star_grid, eps = epsilon_policies, param = baseline_param)

"""
    policy_panel(policies, field; y_label, transform = identity)

Build one panel of the policy figure.

Plots column 1 of `policies.<field>` (solid line) and column 2 (dashed line) against the
matching columns of `policies.z_star_variation`, after applying `transform` elementwise.
The x axis is labelled "Productivity z*" and the y axis `y_label`. Returns the plot.
"""
function policy_panel(policies, field::Symbol; y_label, transform = identity)
	z_values = policies.z_star_variation
	y_values = getproperty(policies, field)
	panel = Plots.plot(z_values[:,1], transform.(y_values[:,1]), label="", lw=2, legend=false)
	Plots.plot!(panel, z_values[:,2], transform.(y_values[:,2]), label="", lw=2, line = :dash, legend=false)
	Plots.xlabel!(panel, "Productivity " * string(L"z^{*}"))
	Plots.ylabel!(panel, y_label)
	return panel
end

# Profits (in logs)
policy_plot_profits = policy_panel(results_policies_baseline, :profits_C_variation; y_label = "log(Profits)", transform = log)

# Revenue (in logs)
policy_plot_revenue = policy_panel(results_policies_baseline, :revenue_C_variation; y_label = "log(Revenue)", transform = log)

# m_R (or rent-seeking share?) 
policy_plot_rentseeking = policy_panel(results_policies_baseline, :rent_seeking_share_variation; y_label = string(L"m_R / (m_R + m)"))

# subsidies
policy_plot_subsidies = policy_panel(results_policies_baseline, :subsidy_variation; y_label = "Subsidy rate")

# Now create a dummy plot just for the shared legend
# (axes hidden; the two NaN series only exist to produce the legend entries)
policy_plot_legend = Plots.plot(legend=:bottom, framestyle=:none, grid=false, ticks=false, showaxis=false, legendcolumns=2, legendfontsize=10)
Plots.plot!(policy_plot_legend, [NaN], label="Low ε", lw=2)
Plots.plot!(policy_plot_legend, [NaN], label="High ε", lw=2, line = :dash)

# Compose the final layout: 2x2 grid of panels on top, thin legend strip underneath
policy_plot_main = Plots.plot(policy_plot_profits, policy_plot_revenue, policy_plot_rentseeking, policy_plot_subsidies, layout=(2,2), legend=false)
policy_plot_final = Plots.plot(policy_plot_main, policy_plot_legend, layout=@layout([a; b{0.04h}]), legend=:bottom, size=(900, 450), left_margin = 5mm)
# `savefig` is qualified like every other plotting call, since several plotting
# packages are loaded at the top of the file.
Plots.savefig(policy_plot_final, joinpath(dout,"results_optimal_policies_baseline.png"))
Plots.savefig(policy_plot_final, joinpath(dout,"Figure4.pdf"))

### Compute baseline aggregates 
# Aggregates of the baseline economy (w = 1, total firm mass normalized to 1), followed
# by interactive ratio checks and two constants that later sections reuse.

# Construct aggregates needed throughout 
baseline_aggregates_goods = compute_aggregates_goods(
	w = 1.0, distrib = baseline_stationary_distrib.SS_distrib,
	mass_firms = 1.0, # normalized in baseline economy 
	mass_entry = baseline_stationary_distrib.mass_entry, 
	within_period_choices = results_profits_baseline, 
	tax = vat_tax, param = baseline_param)

# about 0.16% of total output spent on rent-seeking, 0.3% of total intermediate spending (reported in paper!!) 
baseline_aggregates_goods.total_rent_seeking ./ baseline_aggregates_goods.total_output
baseline_aggregates_goods.total_rent_seeking / (baseline_aggregates_goods.productive_intermediates + baseline_aggregates_goods.total_rent_seeking)

# how much tax revenue spent on subsidizing connected firms? 34.7% 
baseline_aggregates_goods.total_subsidies / baseline_aggregates_goods.gross_tax_revenues
baseline_aggregates_goods.total_subsidies / baseline_aggregates_goods.net_govt_transfers # share over net govt transfers 
(baseline_aggregates_goods.total_subsidies + baseline_aggregates_goods.net_govt_transfers) / baseline_aggregates_goods.net_govt_transfers # (subsidies + net transfers) relative to net transfers 
# subsidies relative to output net of productive intermediates and rent-seeking
baseline_aggregates_goods.total_subsidies / (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates - baseline_aggregates_goods.total_rent_seeking)

# what is the share of entry costs in total profits (total entry costs > total_net_profits)
baseline_aggregates_goods.total_entry_costs/baseline_aggregates_goods.total_net_profits

# how important are indiv parts in HH income? Profits (14%), net govt transfers (21.5%), labor income (64.5%) 
baseline_aggregates_goods.total_net_profits / baseline_aggregates_goods.total_HH_income 
baseline_aggregates_goods.net_govt_transfers / baseline_aggregates_goods.total_HH_income
# labor demand equals labor income here because the baseline wage is 1
baseline_aggregates_goods.total_labor_demand / baseline_aggregates_goods.total_HH_income

# what is baseline govt budget spent on HHs (or other development objectives)? 
const baseline_govt_spending = baseline_aggregates_goods.net_govt_transfers 

# save aggregate labor supply (the labor demand that clears labor market in baseline SS, reported in Table 2)
# Passed as `aggr_L` to every counterfactual equilibrium solver call below.
const aggregate_labor_supply = baseline_aggregates_goods.total_labor_demand


# derive baseline z_grid 
# z = (z_star / x_bar)^(sigma/(sigma-1)) with x_bar = Y^(1/sigma), Y = baseline total output
x_bar_baseline = (baseline_aggregates_goods.total_output)^(1/baseline_param.sigma)
z_grid_baseline = (z_star_grid ./ x_bar_baseline).^(baseline_param.sigma/(baseline_param.sigma-1))
# what about HH income & consumption? 
baseline_aggregates_goods.total_HH_income
# Need initial assets first: set to K_SS in baseline equilibrium (can vary this later!) 
initial_assets = baseline_aggregates_goods.productive_capital # about 75% more than total VA output in baseline 
# Consumption = return on assets, (rental_rate - delta - 1) = 1/beta - 1, plus total HH income
baseline_HH_consumption = (rental_rate - delta - 1)*initial_assets + baseline_aggregates_goods.total_HH_income

## Back out all primitive parameters 
# Copy of baseline_param with beta_eps rescaled by (sigma-1)/sigma and alpha_eps shifted
# by beta_eps*(1/sigma)*log(Y), where Y is baseline total output.
primitive_param = define_param(
	baseline_param, 
	beta_eps = baseline_param.beta_eps*((baseline_param.sigma - 1)/baseline_param.sigma), 
	alpha_eps = baseline_param.alpha_eps + baseline_param.beta_eps*(1/baseline_param.sigma)*log(baseline_aggregates_goods.total_output)
)
### Test: Check that can recover baseline equilibrium (with rent-seeking) 
# The bare expressions below are interactive checks (values are only shown in the REPL).

# Check 1: make sure that profit grids are identical for baseline & primitives 
baseline_profits_grid = compute_profits_grid_baseline(z_star_grid = z_star_grid, n_eps = 500, param = baseline_param)
baseline_profits_grid_primitive = compute_profits_grid_cf(w = 1.0, Y = baseline_aggregates_goods.total_output, tax = nothing, n_eps = 500, param = primitive_param, z_grid = z_grid_baseline)
# Indeed the same! 
# Largest relative deviation in absolute value. Without `abs` the maximum of the signed
# deviations would stay small (or negative) even if some entries were far too low, so
# the check could pass although the grids differ.
maximum(abs.((baseline_profits_grid.expected_profits_C .- baseline_profits_grid_primitive.expected_profits_C) ./ baseline_profits_grid_primitive.expected_profits_C))

## Check 2: Find VF for baseline combination of (w,Y)
test_VF_cf = get_VF_cf(
	w = 1.0, Y = baseline_aggregates_goods.total_output,
	type = "C", 
	tax = vat_tax, 
	param = primitive_param, 
	guess_exp_VF = (1/(1-beta))*profits_NC_baseline,
	verbose = true)
# They are indeed the same! (sums of absolute deviations should be ~0)
sum(abs.(baseline_VF.VF .- test_VF_cf.VF))
sum(abs.(baseline_VF.exit_proba .- test_VF_cf.exit_proba))
## Step 2 (test): baseline stationary distribution will then also look exactly the same! 

## Test 2: Recover baseline equilibrium as validation: Converges directly = correct 
# Started at the baseline values (w = 1, Y = baseline output) with connected firms
# (type = "C") and the primitive parameters.
baseline_eq_SS_goods = find_equilibrium_cf_goods(
	guess_w = 1.0, guess_Y = baseline_aggregates_goods.total_output, 
	type = "C", aggr_L = aggregate_labor_supply,
	param = primitive_param, 
	update_param_w = 0.5, update_param_Y = 0.9, max_iter = 150, max_iter_w = 100)

#endregion 

#region ### 4. Aggr costs of Political Connections: Baseline counterfactual w/ lump-sum rebates (SS + Transition) ### 

## Find baseline cf equilibrium (in SS)
# Counterfactual steady state, solved from the baseline values as initial guess.
# No `type` is passed here, so the solver's default is used (in contrast to the
# type = "C" validation run above).
baseline_cf_eq_SS_goods = find_equilibrium_cf_goods(
	guess_w = 1.0, guess_Y = baseline_aggregates_goods.total_output, 
	param = primitive_param, aggr_L = aggregate_labor_supply,
	update_param_w = 0.5, update_param_Y = 0.9, max_iter = 150, max_iter_w = 100, crit = 1e-12)

## Can compute aggregate gains: 
# (bare expressions = interactive checks; baseline wage and firm mass are 1, hence "- 1")
# total output is about 1.1% lower!
(baseline_cf_eq_SS_goods.Y / baseline_aggregates_goods.total_output) - 1
# This is despite increase in mass of firms (about 3.35%) 
baseline_cf_eq_SS_goods.mass - 1

# Value added output is higher! (2.24% higher, 2.61% higher when taking out rent-seeking) 
(baseline_cf_eq_SS_goods.output_VA / (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates)) - 1
(baseline_cf_eq_SS_goods.output_VA / (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates - baseline_aggregates_goods.total_rent_seeking)) - 1

# wage is about 3.78% lower in new SS equilibrium 
baseline_cf_eq_SS_goods.w - 1 

## Get baseline aggregates 
# Aggregates of the counterfactual steady state; within-period choices are recomputed
# at the equilibrium (w, Y) on the baseline z grid.
baseline_cf_eq_SS_goods_aggregates = compute_aggregates_noC_goods(
	w = baseline_cf_eq_SS_goods.w, 
	distrib = baseline_cf_eq_SS_goods.distrib_results.SS_distrib, 
	mass_firms = baseline_cf_eq_SS_goods.mass, 
	mass_entry = baseline_cf_eq_SS_goods.distrib_results.mass_entry, 
	exit_proba = baseline_cf_eq_SS_goods.VF_results.exit_proba, 
	within_period_choices = compute_profits_NC(w = baseline_cf_eq_SS_goods.w, Y = baseline_cf_eq_SS_goods.Y, param = baseline_param, z_grid = z_grid_baseline), 
	VF_results = baseline_cf_eq_SS_goods.VF_results, 
	param = baseline_param)

# same results (sanity check!) 
(baseline_cf_eq_SS_goods_aggregates.total_output / baseline_aggregates_goods.total_output) - 1 

# Main consumption effects: 7.73% higher consumption without connections (9.2% without assets)
# First line includes the return on initial assets, second line is income only.
((rental_rate - delta - 1)*initial_assets + baseline_cf_eq_SS_goods_aggregates.total_HH_income_noprofits) / ((rental_rate - delta - 1)*initial_assets + baseline_aggregates_goods.total_HH_income_noprofits) - 1 
(baseline_cf_eq_SS_goods_aggregates.total_HH_income_noprofits) / (baseline_aggregates_goods.total_HH_income_noprofits) - 1 
# If include profits: baseline costs are around 5.37% (6.27% without assets)
((rental_rate - delta - 1)*initial_assets + baseline_cf_eq_SS_goods_aggregates.total_HH_income) / ((rental_rate - delta - 1)*initial_assets + baseline_aggregates_goods.total_HH_income) - 1 
(baseline_cf_eq_SS_goods_aggregates.total_HH_income / baseline_aggregates_goods.total_HH_income) - 1 
# How do profits change? Decline by 11.90% 
(baseline_cf_eq_SS_goods_aggregates.total_net_profits / baseline_aggregates_goods.total_net_profits) - 1 
# How do govt transfers change? increase by 48.26% 
# (counterfactual tax revenues are compared with baseline net government transfers)
(baseline_cf_eq_SS_goods_aggregates.tax_revenues / baseline_aggregates_goods.net_govt_transfers) - 1 


#### Compute transition ####

# Length (in years) of the transition path. Defined once so that the initial guesses,
# the diagnostic plots, the terminal-period checks and the final figure all use the
# same horizon (it was previously hard-coded as 150 in every statement).
n_transition_periods = 150

# solve for baseline transition path 
# Initial guesses interpolate linearly between the baseline values and the new steady
# state; the economy starts from the baseline stationary distribution.
# NOTE(review): compute_profits_NC is called here without `z_grid`, unlike the other
# calls in this file which pass z_grid = z_grid_baseline -- confirm the default matches.
baseline_transition_path_cf_goods = find_transition_path_cf_goods(
    guess_w_path = collect(range(1.0, baseline_cf_eq_SS_goods.w, length = n_transition_periods)), 
    guess_Y_path = collect(range(baseline_aggregates_goods.total_output, baseline_cf_eq_SS_goods.Y, length = n_transition_periods)),
	param = baseline_param,
	VF_results_end = baseline_cf_eq_SS_goods.VF_results, 
	aggr_L = aggregate_labor_supply, 
	profits_object_end = compute_profits_NC(w = baseline_cf_eq_SS_goods.w, Y = baseline_cf_eq_SS_goods.Y, param = baseline_param),
    starting_distribution = baseline_stationary_distrib.SS_distrib, 
    max_iter = 150, max_iter_w = 100, crit = 1e-6, 
    update_param_w = 0.5, update_param_Y = 0.9
    )

# Years 1..T (x axis of the diagnostic plots)
transition_years = 1:n_transition_periods

# Wage convergence is monotonic 
Plots.plot(transition_years, baseline_transition_path_cf_goods.w_path ./ baseline_cf_eq_SS_goods.w, legend=:outertopright, label = "Wage path")

# Similarly: output also converges monotonic 
Plots.plot(transition_years, baseline_transition_path_cf_goods.Y ./ baseline_aggregates_goods.total_output, legend=:outertopright, label = "Output")

# check entrant path: first increase in entry, then decline  
Plots.plot(transition_years, baseline_transition_path_cf_goods.mass_entrant_path, legend=:outertopright)

# exit path: increasing monotonically
Plots.plot(transition_years, baseline_transition_path_cf_goods.mass_exit_path, legend=:outertopright)

## Check that transition indeed finally hits new SS (in terms of w,Y,mass,distribution): Correct!! 
# Ratios should be ~1 and the summed absolute deviation ~0 in the last period.
baseline_cf_eq_SS_goods.w / baseline_transition_path_cf_goods.w_path[n_transition_periods] # wage is the same! 
baseline_cf_eq_SS_goods.Y / baseline_transition_path_cf_goods.Y[n_transition_periods] # output is the same! 
baseline_cf_eq_SS_goods.mass / sum(baseline_transition_path_cf_goods.distribution_path[n_transition_periods]) # mass is the same! 
sum(abs.((baseline_cf_eq_SS_goods.distrib_results.SS_distrib.*baseline_cf_eq_SS_goods.mass) .- baseline_transition_path_cf_goods.distribution_path[n_transition_periods])) # distribution is also the same!! 

#### Plot!! #### 

# Total mass of firms in each year = sum of that year's distribution
baseline_transition_path_cf_goods_mass = [sum(baseline_transition_path_cf_goods.distribution_path[year]) for year in transition_years]

# plot joint plot with transition results: mass of entrants, exits, total mass, wage & output 
# Year 0 is the baseline steady state, which is prepended to every series; all series
# are expressed as deviations from their baseline value.
transition_years_from_zero = 0:n_transition_periods
Plots.hline([0.0], 
	linestyle=:dash, color=:gray, lw = 1.5, 
	label = "", 
	xlabel = "Years of transition", 
	ylabel = "% deviations from baseline", 
	dpi = 200,
	thickness_scaling = 1.3)
Plots.plot!(transition_years_from_zero, [1.0;baseline_transition_path_cf_goods.w_path] .- 1, label = "Wage", lw = 2.0)
Plots.plot!(transition_years_from_zero, [baseline_aggregates_goods.total_output; baseline_transition_path_cf_goods.Y] ./ baseline_aggregates_goods.total_output .- 1, label = "Gross output", lw = 2.0)
Plots.plot!(transition_years_from_zero, [baseline_stationary_distrib.mass_entry; baseline_transition_path_cf_goods.mass_entrant_path] ./ baseline_stationary_distrib.mass_entry .- 1, label = "Mass entrants", lw = 2.0)
# exits are normalized by baseline entry mass as well (the baseline exit mass is set to the entry mass here)
Plots.plot!(transition_years_from_zero, [baseline_stationary_distrib.mass_entry; baseline_transition_path_cf_goods.mass_exit_path] ./ baseline_stationary_distrib.mass_entry .- 1, label = "Mass exits", lw = 2.0)
Plots.plot!(transition_years_from_zero, [1.0; baseline_transition_path_cf_goods_mass] .- 1, label = "Mass total", lw = 2.0)
# Only the first 30 years are shown
Plots.plot!(xlimits=(0,30), ylimits = (-0.06,0.35))
yticks_vals = [-0.05,0.0,0.05,0.1,0.2,0.3] #-0.04:0.01:0.02  # Choose tick positions
yticks_labels = ["$(round(Int, 100 * y))%" for y in yticks_vals]
Plots.yticks!(yticks_vals, yticks_labels)
# Saves the current plot; qualified like every other plotting call in this file.
Plots.savefig(joinpath(dout,"results_baseline_cf_transition.png"))


#endregion 

#region ### 5. GE effects with constant subsidy rate ### 

### Single constant rate 

# Fixed subsidy rate used in this counterfactual: the mean of the sampled baseline
# subsidy rates. Computed once so that the equilibrium solver and the subsidy-amount
# computation below are guaranteed to use the same rate.
avg_baseline_subsidy = mean(sample_subsidies)

# solve for equilibrium with fixed subsidy across connected firms (and assume no costs!) 
baseline_cf_Cfix = find_equilibrium_cf_goods(
	guess_w = 1.0, guess_Y = baseline_aggregates_goods.total_output, 
	type = "C-fix", 
	subsidy = avg_baseline_subsidy, ## Set to average subsidy rate instead 
	param = baseline_param, 
	aggr_L = aggregate_labor_supply,
	update_param_w = 0.5, update_param_Y = 0.9, max_iter = 150, max_iter_w = 100)

## Can compute welfare gains: 
# (bare expressions = interactive checks, relative to the baseline economy)
# total output is about 0.265% higher
(baseline_cf_Cfix.Y / baseline_aggregates_goods.total_output) - 1
# This is despite higher overall mass of firms (about 1.5% more firms in new equilibrium)
baseline_cf_Cfix.mass - 1

# Value added output is higher (1.6% higher, 1.96% higher when taking out rent-seeking) 
(baseline_cf_Cfix.output_VA / (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates)) - 1
(baseline_cf_Cfix.output_VA / (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates - baseline_aggregates_goods.total_rent_seeking)) - 1

# wage is about 0.8% lower in new SS equilibrium 
baseline_cf_Cfix.w - 1 

## get aggregates! 
# The within-period choices are assembled by hand from the equilibrium results:
# rent-seeking is zero in every state, and subsidy amounts come from the fixed-rate
# profit grid. The zero vector is sized from the distribution it is aggregated with
# rather than hard-coded to 1000 grid points.
baseline_cf_Cfix_aggregates = compute_aggregates_goods(
	w = baseline_cf_Cfix.w, 
	distrib = baseline_cf_Cfix.distrib_results.SS_distrib,
	mass_firms = baseline_cf_Cfix.mass, 
	mass_entry = baseline_cf_Cfix.distrib_results.mass_entry, 
	within_period_choices = (
		expected_revenue = baseline_cf_Cfix.VF_results.revenue, 
		expected_revenue_output = baseline_cf_Cfix.VF_results.output_aggr, 
		expected_m_R_C = zeros(length(baseline_cf_Cfix.distrib_results.SS_distrib)), 
		expected_profits = baseline_cf_Cfix.VF_results.exp_profits, 
		expected_subsidies_C = (compute_profits_grid_cf_fix_subsidy(
			w = baseline_cf_Cfix.w, Y = baseline_cf_Cfix.Y, 
			subsidy = avg_baseline_subsidy, tax = nothing, param = baseline_param, z_grid = z_grid_baseline)).subsidy_amount
	),
	tax = vat_tax, param = baseline_param)	

# same results 
(baseline_cf_Cfix_aggregates.total_output / baseline_aggregates_goods.total_output) - 1 

# Main consumption effects: 3.79% higher consumption without connections (4.5% without assets)
((rental_rate - delta - 1)*initial_assets + baseline_cf_Cfix_aggregates.total_HH_income_noprofits) / ((rental_rate - delta - 1)*initial_assets + baseline_aggregates_goods.total_HH_income_noprofits) - 1 
(baseline_cf_Cfix_aggregates.total_HH_income_noprofits) / (baseline_aggregates_goods.total_HH_income_noprofits) - 1 

# If include profits: baseline costs are around 3.1-3.6% 
((rental_rate - delta - 1)*initial_assets + baseline_cf_Cfix_aggregates.total_HH_income) / ((rental_rate - delta - 1)*initial_assets + baseline_aggregates_goods.total_HH_income) - 1 
(baseline_cf_Cfix_aggregates.total_HH_income / baseline_aggregates_goods.total_HH_income) - 1 

# How do profits change? Decrease by 2.02%
(baseline_cf_Cfix_aggregates.total_net_profits / baseline_aggregates_goods.total_net_profits) - 1 

# How do govt transfers change? 20.54% higher 
(baseline_cf_Cfix_aggregates.net_govt_transfers / baseline_aggregates_goods.net_govt_transfers) - 1 


### Solve for "optimal" level of subsidies to connected firms ###

# create grid of average subsidy rates 
# 100 equally spaced rates between 0 and twice the average sampled baseline rate
level_subsidy_grid = collect( range(0.0, stop = 2*mean(sample_subsidies), length = 100) )

# save output results in vector 
# (one entry per grid point; sized from the grid so the two cannot get out of sync)
output_results_level_subsidy_grid = zeros(length(level_subsidy_grid))

# Baseline value added net of rent-seeking: the common denominator for every grid
# point, so it is computed once outside the loop.
baseline_VA_net_rent_seeking = baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates - baseline_aggregates_goods.total_rent_seeking

# loop through equilibria (takes a few moments; less than 1min on my MacBook)
for i in eachindex(level_subsidy_grid)
	# show progress
	println("Solving entry ", i)

	# solve result: steady state with subsidy rate fixed at the i-th grid value
	output_result_subsidy_grid = find_equilibrium_cf_goods(
		guess_w = 1.0, guess_Y = baseline_aggregates_goods.total_output, 
		type = "C-fix", 
		subsidy = level_subsidy_grid[i],
		param = baseline_param, 
		aggr_L = aggregate_labor_supply,
		update_param_w = 0.5, update_param_Y = 0.9, max_iter = 150, max_iter_w = 100)

	# save output result: value-added output relative to baseline, minus 1
	output_results_level_subsidy_grid[i] = (output_result_subsidy_grid.output_VA / baseline_VA_net_rent_seeking) - 1
end 

# Grid position with the largest value-added gain
optimal_subsidy_index = argmax(output_results_level_subsidy_grid)

## print results and see optimum (Baseline level is not far from optimum!) 
# Plotting calls are qualified with `Plots.` like elsewhere in this file, since several
# plotting packages are loaded.
Plots.plot(
	level_subsidy_grid, output_results_level_subsidy_grid, 
	primary=false, color = "red", 
	xlabel = "Fixed subsidy rate",
	ylabel = "Aggregate Output relative to Baseline", dpi = 200)
Plots.vline!([mean(sample_subsidies)], label="Baseline average subsidy")
Plots.vline!([level_subsidy_grid[optimal_subsidy_index]], label="Optimal subsidy level", color = "black", linestyle = :dash)
Plots.savefig(joinpath(dout,"results_optimal_fixed_subsidy_rate.png"))

# which one is optimal? In the authors' run: grid point 9, i.e. a subsidy rate of about 5.2%
optimal_subsidy_index # number 9
level_subsidy_grid[optimal_subsidy_index]
maximum(output_results_level_subsidy_grid) # maximum gains would be close to 3.1%

#endregion 

#region ### 6. Aggr costs of Political Connections with alternative DRS technology ###

## Load data/parameters ## 
# Inputs for the alternative (DRS) calibration, all read from CSV files in the data folder.

# z_star grid & transition matrix 
DRS_z_star_grid  = DataFrame(CSV.File(joinpath(ddat,"DRS_z_star_grid.csv"); header=1))[:,1] # first column of the file
DRS_transition_z    = Array(DataFrame(CSV.File(joinpath(ddat,"DRS_transition_matrix_z.csv"); header=1)))
# NOTE(review): DRS_transition_z is not referenced again in this file; presumably it is
# read as a global by the *_DRS helper functions -- confirm.

# load parameters 
DRS_parameters_df = DataFrame(CSV.File(joinpath(ddat,"DRS_parameters.csv"); header=1))

# get baseline unconditional z_star distribution 
DRS_z_star_uncond_distrib = DataFrame(CSV.File(joinpath(ddat,"DRS_z_star_uncond_distrib.csv"); header=1)).share

# define baseline parameters 
# Values come from the first row of the parameter file; the three connection cost terms
# are set to zero, and proba_c and sigma are carried over from the baseline calibration.
DRS_param = define_param(
	level_fcost = DRS_parameters_df.level_fcost[1], 
	scale_fcost = DRS_parameters_df.scale_fcost[1], 
	connect_theta = DRS_parameters_df.theta[1], 
	connect_cost_level = 0.0, 
	connect_cost_elasticity = 0.0, 
	connect_fixed_cost = 0.0, 
	connect_rho = DRS_parameters_df.rho[1], 
	connect_proba_c = baseline_param.proba_c,
	proba_c = baseline_param.proba_c,  
	alpha_eps = DRS_parameters_df.alpha_eps[1], 
	beta_eps = DRS_parameters_df.beta_eps[1], 
	variance_eps_z = DRS_parameters_df.variance_eps_z[1], 
	sigma = baseline_param.sigma 
	) 

# Same three steps as for the baseline economy, using the *_DRS helper variants.

# Step 1: Get expected profits by state 
DRS_results_profits_baseline = compute_profits_grid_baseline_DRS(z_star_grid = DRS_z_star_grid, n_eps = 500, param = DRS_param)

# Step 2: find baseline VF
# NOTE(review): the initial guess reuses profits_NC_baseline from the main calibration,
# not a DRS-specific object -- this assumes both grids have the same length; confirm.
DRS_baseline_VF = get_VF_baseline_DRS(
	param = DRS_param, 
	guess_exp_VF = (1/(1-beta))*profits_NC_baseline, 
	exp_profits = DRS_results_profits_baseline.expected_profits, 
	verbose = true)

## This is in line with R
# can also back out implied entry cost (which is in line with free entry condition and w = 1) 
# Value function weighted by the unconditional DRS z* distribution
const DRS_entry_cost = sum(DRS_baseline_VF.VF .* DRS_z_star_uncond_distrib) # Unconditional value of entry 

# Step 3: Find baseline stationary distribution
DRS_baseline_stationary_distrib = get_stationary_distrib_DRS(
	initial_distrib = DRS_z_star_uncond_distrib, 
	entry_distrib = DRS_z_star_uncond_distrib,
	exit_proba = DRS_baseline_VF.exit_proba, verbose = true, crit = 1e-12)

# Construct aggregates needed throughout 
# DRS baseline economy (w = 1, firm mass normalized to 1); unlike the main baseline
# call, the entry cost is passed explicitly here.
DRS_baseline_aggregates_goods = compute_aggregates_goods(
		w = 1.0, distrib = DRS_baseline_stationary_distrib.SS_distrib,
		mass_firms = 1.0, # normalized in baseline economy 
		mass_entry = DRS_baseline_stationary_distrib.mass_entry, 
		within_period_choices = DRS_results_profits_baseline, 
		entry_cost = DRS_entry_cost, 
		tax = vat_tax, param = DRS_param)

# The bare expressions below are interactive checks (values are only shown in the REPL).

# about 0.5% of total output spent on rent-seeking (That is more than 2x as large as in baseline!)  
DRS_baseline_aggregates_goods.total_rent_seeking ./ DRS_baseline_aggregates_goods.total_output
DRS_baseline_aggregates_goods.total_rent_seeking / (DRS_baseline_aggregates_goods.productive_intermediates + DRS_baseline_aggregates_goods.total_rent_seeking)

# how much tax revenue spent on subsidizing connected firms? 42% (similar to baseline) 
DRS_baseline_aggregates_goods.total_subsidies / DRS_baseline_aggregates_goods.gross_tax_revenues
DRS_baseline_aggregates_goods.total_subsidies / DRS_baseline_aggregates_goods.net_govt_transfers # share over net govt transfers 
(DRS_baseline_aggregates_goods.total_subsidies + DRS_baseline_aggregates_goods.net_govt_transfers) / DRS_baseline_aggregates_goods.net_govt_transfers # (subsidies + net transfers) relative to net transfers 

# How about total subsidies? Very similar
DRS_baseline_aggregates_goods.total_subsidies / baseline_aggregates_goods.total_subsidies

# Net govt transfers? Net govt transfers quite a bit lower
DRS_baseline_aggregates_goods.net_govt_transfers / baseline_aggregates_goods.net_govt_transfers

# Subsidies over total transfers? 72.7 vs. 53.2. So subsidies are much more important in DRS economy! 
DRS_baseline_aggregates_goods.total_subsidies / DRS_baseline_aggregates_goods.net_govt_transfers
baseline_aggregates_goods.total_subsidies / baseline_aggregates_goods.net_govt_transfers

# How about profits? They are quite a bit lower in DRS model! Hence also lower profit tax income 
DRS_baseline_aggregates_goods.total_net_profits / baseline_aggregates_goods.total_net_profits

# what is baseline govt budget spent on HHs (or other development objectives)? 
const DRS_baseline_govt_spending = DRS_baseline_aggregates_goods.net_govt_transfers 

# save aggregate labor supply (the labor demand that clears labor market in baseline SS) (slightly different from Table 2)
# Both constants hold the same value; only the `_goods` one is used further down in this file.
const DRS_aggregate_labor_supply = DRS_baseline_aggregates_goods.total_labor_demand
const DRS_aggregate_labor_supply_goods = DRS_baseline_aggregates_goods.total_labor_demand

# how important are indiv parts in HH income? (Similar to baseline) Profits (11.8%), net govt transfers (19.91%), labor income (68.29%) 
DRS_baseline_aggregates_goods.total_net_profits / DRS_baseline_aggregates_goods.total_HH_income 
DRS_baseline_aggregates_goods.net_govt_transfers / DRS_baseline_aggregates_goods.total_HH_income
DRS_baseline_aggregates_goods.total_labor_demand / DRS_baseline_aggregates_goods.total_HH_income

# derive baseline z_grid 
# z = (z_star / x_bar)^(sigma/(sigma-1)) with x_bar = Y^(1/sigma), Y = DRS baseline total output
DRS_x_bar_baseline = (DRS_baseline_aggregates_goods.total_output)^(1/DRS_param.sigma)
DRS_z_grid_baseline = (DRS_z_star_grid ./ DRS_x_bar_baseline).^(DRS_param.sigma/(DRS_param.sigma-1))

# Need initial assets first: set to K_SS in baseline equilibrium (can vary this later!) 
DRS_initial_assets = DRS_baseline_aggregates_goods.productive_capital 
# Consumption = return on assets, (rental_rate - delta - 1) = 1/beta - 1, plus total HH income
DRS_baseline_HH_consumption = (rental_rate - delta - 1)*DRS_initial_assets + DRS_baseline_aggregates_goods.total_HH_income

## Back out all primitive parameters 

# Copy of DRS_param with beta_eps rescaled by (sigma-1)/sigma and alpha_eps shifted by
# beta_eps*(1/sigma)*log(Y), where Y is DRS baseline total output.
DRS_primitive_param = define_param(
	DRS_param, 
	beta_eps = DRS_param.beta_eps*((DRS_param.sigma - 1)/DRS_param.sigma), 
	alpha_eps = DRS_param.alpha_eps + DRS_param.beta_eps*(1/DRS_param.sigma)*log(DRS_baseline_aggregates_goods.total_output)
)

#### Next: Compute Aggregate costs of political connections 

## Find baseline cf equilibrium (in SS)
# Counterfactual steady state under the DRS calibration, started at the DRS baseline
# values; iteration limits are higher (250) than in the main calibration.
DRS_baseline_cf_eq_SS_goods = find_equilibrium_cf_DRS_goods(
		guess_w = 1.0, guess_Y = DRS_baseline_aggregates_goods.total_output, 
		param = DRS_primitive_param, aggr_L = DRS_aggregate_labor_supply_goods,
		update_param_w = 0.5, update_param_Y = 0.9, max_iter = 250, max_iter_w = 250)


# The bare expressions below are interactive checks relative to the DRS baseline.

## Can compute welfare gains: gross output is about 1.4% higher!
(DRS_baseline_cf_eq_SS_goods.Y / DRS_baseline_aggregates_goods.total_output) - 1 # 

# This is partly because of more firms! (about 4.73% more firms in new equilibrium)
DRS_baseline_cf_eq_SS_goods.mass - 1 

# Value added output is a lot higher (6.85% higher when taking out rent-seeking) 
(DRS_baseline_cf_eq_SS_goods.output_VA / (DRS_baseline_aggregates_goods.total_output - DRS_baseline_aggregates_goods.productive_intermediates)) - 1
(DRS_baseline_cf_eq_SS_goods.output_VA / (DRS_baseline_aggregates_goods.total_output - DRS_baseline_aggregates_goods.productive_intermediates - DRS_baseline_aggregates_goods.total_rent_seeking)) - 1

# wage is 1.86% lower in new SS equilibrium 
DRS_baseline_cf_eq_SS_goods.w - 1 

## Get aggregates 
# Within-period choices are recomputed at the equilibrium (w, Y) on the DRS z grid;
# the DRS entry cost is passed explicitly.
DRS_baseline_cf_aggregates_goods = compute_aggregates_noC_goods(
	w = DRS_baseline_cf_eq_SS_goods.w, 
	distrib = DRS_baseline_cf_eq_SS_goods.distrib_results.SS_distrib, 
	mass_firms = DRS_baseline_cf_eq_SS_goods.mass, 
	mass_entry = DRS_baseline_cf_eq_SS_goods.distrib_results.mass_entry, 
	exit_proba = DRS_baseline_cf_eq_SS_goods.VF_results.exit_proba, 
	within_period_choices = compute_profits_NC(w = DRS_baseline_cf_eq_SS_goods.w, Y = DRS_baseline_cf_eq_SS_goods.Y, param = DRS_param, z_grid = DRS_z_grid_baseline), 
	VF_results = DRS_baseline_cf_eq_SS_goods.VF_results, 
	entry_cost = DRS_entry_cost,
	param = DRS_param)

# same results 
(DRS_baseline_cf_aggregates_goods.total_output / DRS_baseline_aggregates_goods.total_output) - 1 

# Main consumption effects: +12.50% (+15.0% without assets)
((rental_rate - delta - 1)*DRS_initial_assets + DRS_baseline_cf_aggregates_goods.total_HH_income_noprofits) / ((rental_rate - delta - 1)*DRS_initial_assets + DRS_baseline_aggregates_goods.total_HH_income_noprofits) - 1 
(DRS_baseline_cf_aggregates_goods.total_HH_income_noprofits) / (DRS_baseline_aggregates_goods.total_HH_income_noprofits) - 1 

# If include profits: baseline costs are around 10.66%
((rental_rate - delta - 1)*DRS_initial_assets + DRS_baseline_cf_aggregates_goods.total_HH_income) / ((rental_rate - delta - 1)*DRS_initial_assets + DRS_baseline_aggregates_goods.total_HH_income) - 1 
(DRS_baseline_cf_aggregates_goods.total_HH_income / DRS_baseline_aggregates_goods.total_HH_income) - 1 

# How do profits change? -5.80% 
(DRS_baseline_cf_aggregates_goods.total_net_profits / DRS_baseline_aggregates_goods.total_net_profits) - 1 

# How do govt transfers change? +72.84%
# (counterfactual tax revenues are compared with baseline net government transfers)
(DRS_baseline_cf_aggregates_goods.tax_revenues / DRS_baseline_aggregates_goods.net_govt_transfers) - 1 


#endregion 

#region ### 7. GE effects without entry/exit ###

# Counterfactual equilibrium without entry/exit: uses the solver variant that takes the
# firm distribution and the firm mass as given instead of solving for them
baseline_cf_noEE = find_equilibrium_cf_fixdistrib(;
    param = baseline_param,
    aggr_L = aggregate_labor_supply,
    distrib = baseline_stationary_distrib.SS_distrib,  # baseline SS distribution stays fixed
    mass = 1.0,                                        # initial mass stays fixed as well
    guess_w = 1.0,
    update_param_w = 0.5,
    crit = 1e-6,
    max_iter = 500,
    verbose = true,
)

# Gross output: 2.8% lower, compared to 4.6% lower w/ EE
baseline_cf_noEE.Y / baseline_aggregates_goods.total_output - 1

# The mass is fixed in this exercise, so there is no mass effect to inspect.

# Wage effect (deviation from 1): about -5.4%
baseline_cf_noEE.w - 1

# Value added: the VA output effect is slightly positive w/out EE -- around 0.85%.
# Baseline VA = gross output - productive intermediates - rent-seeking spending.
baseline_cf_noEE.output_VA / (
    baseline_aggregates_goods.total_output -
    baseline_aggregates_goods.productive_intermediates -
    baseline_aggregates_goods.total_rent_seeking
) - 1


## Aggregates of the no-entry/exit counterfactual
# Distribution and mass are the baseline ones; within-period choices are evaluated at the
# counterfactual wage and output.
baseline_cf_noEE_aggregates = compute_aggregates_noC_noEE(;
    w = baseline_cf_noEE.w,
    mass_firms = 1.0,                                  # initial mass stays fixed
    distrib = baseline_stationary_distrib.SS_distrib,  # baseline SS distribution stays fixed
    within_period_choices = compute_profits_NC(;
        w = baseline_cf_noEE.w,
        Y = baseline_cf_noEE.Y,
        param = baseline_param,
        z_grid = z_grid_baseline,
    ),
)

# Cross-check of the gross output change (original note: same results, -2.8%)
baseline_cf_noEE_aggregates.total_output / baseline_aggregates_goods.total_output - 1

# Main consumption effects: 6.16% higher consumption without connections
(
    (rental_rate - delta - 1) * initial_assets +
    baseline_cf_noEE_aggregates.total_HH_income_noprofits
) / (
    (rental_rate - delta - 1) * initial_assets +
    baseline_aggregates_goods.total_HH_income_noprofits
) - 1

# Including profits: costs are still 2.5%.
# Baseline entry costs (mass_entry * entry_cost) are added to the baseline denominator.
(
    (rental_rate - delta - 1) * initial_assets +
    baseline_cf_noEE_aggregates.total_HH_income
) / (
    (rental_rate - delta - 1) * initial_assets +
    baseline_aggregates_goods.total_HH_income +
    baseline_stationary_distrib.mass_entry * entry_cost
) - 1

# Profits decline by 4.52% (entry costs are added back to the baseline, otherwise the two
# numbers are not comparable)
baseline_cf_noEE_aggregates.total_net_profits / (
    baseline_aggregates_goods.total_net_profits +
    baseline_aggregates_goods.total_entry_costs
) - 1

# Government transfers increase by 45.72%
baseline_cf_noEE_aggregates.tax_revenues / baseline_aggregates_goods.net_govt_transfers - 1

## How does EE affect the selection of firms & average productivity?
# Distribution-weighted sum of the productivity grid, baseline vs. counterfactual with EE
sum(baseline_stationary_distrib.SS_distrib .* z_grid_baseline)              # baseline avg productivity
sum(baseline_cf_eq_SS_goods.distrib_results.SS_distrib .* z_grid_baseline)  # indeed slightly smaller

#endregion 

#region ### 8. Aggr costs of Political Connections with lower taxes ###

#### Tax-rate counterfactual: shut down political connections and adapt the tax rate so
#### that G stays constant ####

## Solve for the SS counterfactual (distribution and mass fixed at baseline values)
# aggr_T passes baseline net government transfers to the solver.
# NOTE(review): crit = 1e-16 is below Float64 machine epsilon (~2.2e-16), while the other
# solver calls in this file use 1e-6. Depending on how the solver applies `crit`, the loop
# may only stop at max_iter -- confirm this tolerance is intended.
tax_cf_eq_SS_noEE = find_equilibrium_cf_tax_goods_fixdistrib(;
    type = "NC",
    param = baseline_param,
    aggr_L = aggregate_labor_supply,
    aggr_T = baseline_aggregates_goods.net_govt_transfers,
    distrib = baseline_stationary_distrib.SS_distrib,
    mass = 1.0,
    guess_w = 1.0,
    guess_Y = baseline_aggregates_goods.total_output,
    guess_tax = 0.075,
    update_param_w = 0.5,
    update_param_Y = 0.75,
    update_param_tax = 0.75,
    crit = 1e-16,
    max_iter = 500,
    verbose = false,
)

# Resulting tax rate is 4.1% now (about half!)
tax_cf_eq_SS_noEE.tax

## Welfare gains
# Gross output is now higher: about 1.95%
tax_cf_eq_SS_noEE.Y / baseline_aggregates_goods.total_output - 1

# By construction there are no mass changes.

# Value added output gains are larger: about 5.78%
tax_cf_eq_SS_noEE.output_VA / (
    baseline_aggregates_goods.total_output -
    baseline_aggregates_goods.productive_intermediates -
    baseline_aggregates_goods.total_rent_seeking
) - 1

# Wage is now also higher in the new SS equilibrium: about 5.39%
tax_cf_eq_SS_noEE.w - 1

## Aggregates of the tax counterfactual
# Within-period choices and aggregates are both evaluated at the counterfactual tax rate.
tax_cf_eq_aggregates_noEE = compute_aggregates_noC_goods_noEE(;
    w = tax_cf_eq_SS_noEE.w,
    tax = tax_cf_eq_SS_noEE.tax,
    mass_firms = 1.0,
    distrib = baseline_stationary_distrib.SS_distrib,
    within_period_choices = compute_profits_NC(;
        w = tax_cf_eq_SS_noEE.w,
        Y = tax_cf_eq_SS_noEE.Y,
        tax = tax_cf_eq_SS_noEE.tax,
        param = baseline_param,
        z_grid = z_grid_baseline,
    ),
)

# Consistency check: aggregated output vs. equilibrium Y (original note: same results)
tax_cf_eq_aggregates_noEE.total_output / tax_cf_eq_SS_noEE.Y - 1

## Consumption costs

# Main consumption effects: 3.58% higher consumption without connections
# (smaller than in baseline!)
(
    (rental_rate - delta - 1) * initial_assets +
    tax_cf_eq_aggregates_noEE.total_HH_income_noprofits
) / (
    (rental_rate - delta - 1) * initial_assets +
    baseline_aggregates_goods.total_HH_income_noprofits
) - 1

## Why smaller now?

# Including profits: only about 2.3% larger.
# Baseline entry costs (entry_cost * mass_entry) are added to the baseline denominator.
(
    (rental_rate - delta - 1) * initial_assets +
    tax_cf_eq_aggregates_noEE.total_HH_income
) / (
    (rental_rate - delta - 1) * initial_assets +
    baseline_aggregates_goods.total_HH_income +
    entry_cost * baseline_stationary_distrib.mass_entry
) - 1

# Profits stay almost unchanged: 0.14%
tax_cf_eq_aggregates_noEE.total_net_profits / (
    baseline_aggregates_goods.total_net_profits +
    entry_cost * baseline_stationary_distrib.mass_entry
) - 1

# Check tax revenues: T indeed stays constant (ratio to baseline net transfers)
tax_cf_eq_aggregates_noEE.tax_revenues / baseline_aggregates_goods.net_govt_transfers

# Share of tax revenues in total HH income (excluding profits): about 24.5%
tax_cf_eq_aggregates_noEE.tax_revenues / tax_cf_eq_aggregates_noEE.total_HH_income_noprofits

### So the consumption effect is smaller now because T stays fixed & the positive wage
### effects don't compensate

#endregion 

#region ### 9. Aggregate benefits of increased auditing ###

# Grid of auditing-cost multipliers: 20 evenly spaced points on [0.1, 2.0] plus three
# larger values
auditing_cost_grid = vcat(range(0.1, stop = 2.0, length = 20), [2.5, 3.0, 5.0])

# Auditing price derived from the baseline government budget share on auditing
# (173 / 32834 of gross tax revenues), per unit of the baseline connection cost level
const auditing_price =
    baseline_aggregates_goods.gross_tax_revenues * (173 / 32834) /
    baseline_param.connect_cost_level
# Ten times the main auditing price
const auditing_price_UB = 10 * auditing_price

# Container for the equilibrium of every grid point (pre-filled with the baseline
# equilibrium; each entry is overwritten inside the loop below)
cf_results_auditing_cost_grid = fill(baseline_equilibrium, length(auditing_cost_grid))

# Containers for aggregates and a bunch of derived objects, one entry per grid point
aggregate_results_auditing_cost_grid =
    fill(baseline_aggregates_goods, length(auditing_cost_grid))
mass_results_auditing_cost_grid = zeros(Float64, length(auditing_cost_grid))
output_auditing_cost_grid = zeros(Float64, length(auditing_cost_grid))
VA_output_auditing_cost_grid = zeros(Float64, length(auditing_cost_grid))
wage_auditing_cost_grid = zeros(Float64, length(auditing_cost_grid))
consumption_auditing_cost_grid = zeros(Float64, length(auditing_cost_grid))
consumption_auditing_cost_grid_UB = zeros(Float64, length(auditing_cost_grid))

# Solve one equilibrium per grid point (this takes a couple of minutes). Iterations run on
# multiple threads; each iteration only writes to slot i of the result containers.
Threads.@threads for i in eachindex(auditing_cost_grid)

    # progress message
    println("Solving entry ", i, " from ", length(auditing_cost_grid))

    # parameters with the connection cost level scaled by the i-th grid value
    scaled_param = define_param(
        primitive_param;
        connect_cost_level = primitive_param.connect_cost_level * auditing_cost_grid[i],
    )

    # solve the equilibrium, store it, and read it back from the container
    cf_results_auditing_cost_grid[i] = find_equilibrium_cf_goods(;
        type = "C",
        param = scaled_param,
        aggr_L = aggregate_labor_supply,
        guess_w = 1.0,
        guess_Y = baseline_aggregates_goods.total_output,
        update_param_w = 0.5,
        update_param_Y = 0.9,
        max_iter = 150,
        max_iter_w = 100,
    )
    eq_i = cf_results_auditing_cost_grid[i]

    # headline equilibrium objects
    output_auditing_cost_grid[i] = eq_i.Y
    wage_auditing_cost_grid[i] = eq_i.w
    mass_results_auditing_cost_grid[i] = eq_i.mass

    # within-period profit results at the equilibrium (w, Y)
    profit_results_i = compute_profits_grid_cf(;
        w = eq_i.w,
        Y = eq_i.Y,
        tax = vat_tax,
        n_eps = 500,
        param = scaled_param,
        z_grid = z_grid_baseline,
    )

    # aggregates (needed for the correct measure of VA output)
    aggregate_results_auditing_cost_grid[i] = compute_aggregates_goods(;
        w = eq_i.w,
        distrib = eq_i.distrib_results.SS_distrib,
        mass_firms = eq_i.mass,
        mass_entry = eq_i.distrib_results.mass_entry,
        within_period_choices = profit_results_i,
        tax = vat_tax,
        param = scaled_param,
    )
    aggr_i = aggregate_results_auditing_cost_grid[i]

    # value added output: gross output - productive intermediates - rent-seeking
    VA_output_auditing_cost_grid[i] =
        aggr_i.total_output - aggr_i.productive_intermediates - aggr_i.total_rent_seeking

    # consumption: asset income + HH income (no profits) - auditing expenditure, once with
    # the main auditing price and once with the upper-bound price
    income_before_auditing_i =
        (rental_rate - delta - 1) * initial_assets + aggr_i.total_HH_income_noprofits
    consumption_auditing_cost_grid[i] =
        income_before_auditing_i - (auditing_price * scaled_param.connect_cost_level)
    consumption_auditing_cost_grid_UB[i] =
        income_before_auditing_i - (auditing_price_UB * scaled_param.connect_cost_level)
end

# Reference points on the auditing grid, looked up by value rather than by position so the
# results below stay correct if the grid is re-spaced:
#   c = 1.0 -> baseline level of oversight, c = 3.0 -> tripled oversight
auditing_baseline_idx = findfirst(c -> isapprox(c, 1.0), auditing_cost_grid)
auditing_tripled_idx = findfirst(c -> isapprox(c, 3.0), auditing_cost_grid)
if isnothing(auditing_baseline_idx) || isnothing(auditing_tripled_idx)
    error("auditing_cost_grid must contain the reference points 1.0 (baseline) and 3.0 (tripled oversight)")
end

# Relative deviations: VA output relative to baseline VA from the baseline aggregates;
# consumption relative to the grid point with c = 1.0
VA_output_auditing_cost_grid_relative = (VA_output_auditing_cost_grid ./ (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates - baseline_aggregates_goods.total_rent_seeking)) .- 1.0
consumption_auditing_cost_grid_UB_relative = (consumption_auditing_cost_grid_UB ./ consumption_auditing_cost_grid_UB[auditing_baseline_idx]) .- 1.0
consumption_auditing_cost_grid_relative = (consumption_auditing_cost_grid ./ consumption_auditing_cost_grid[auditing_baseline_idx]) .- 1.0

# Plot consumption (main and lower bound) and VA output against the level of oversight.
# Open question from the original notes: why is VA output strictly declining? Isn't the
# level of subsidies getting too high eventually?
Plots.hline([0.0], linestyle=:dash, color=:gray, lw = 1.5, label = "")
Plots.plot!(auditing_cost_grid, consumption_auditing_cost_grid_relative, label = "Main (Consumption)", color = "red", lw = 3.0)
# the upper-bound auditing price yields the lower bound for consumption
Plots.plot!(auditing_cost_grid, consumption_auditing_cost_grid_UB_relative, label = "Lower bound (Consumption)", color = "red", lw = 2.0, linestyle = :dashdot)
Plots.plot!(
	auditing_cost_grid, VA_output_auditing_cost_grid_relative, 
	label = "VA Output (for main)",
	xlabel = "Level of public oversight c (norm.)", 
	ylabel = "% deviations from baseline", 
	lw = 1.0,
	dpi = 200, color = "blue",
	linestyle = :dashdotdot, 
	thickness_scaling = 1.3)
Plots.vline!([1.0], label = "Baseline", color=:black, linestyle=:dash)
Plots.plot!(xlimits=(0.1,3.5))
yticks_vals = [-0.2,-0.15,-0.1,-0.05,0.0,0.02]  # tick positions
yticks_labels = ["$(round(Int, 100 * y))%" for y in yticks_vals]  # shown as whole percentages
Plots.yticks!(yticks_vals, yticks_labels)
Plots.savefig(joinpath(dout,"results_optimal_auditing.png"))
Plots.savefig(joinpath(dout,"Figure5.pdf"))

## What is doubling the auditing budget in terms of GDP? Additional 0.095% of GDP
# (baseline auditing budget relative to VA output at the baseline grid point)
(baseline_aggregates_goods.gross_tax_revenues * (173 / 32834)) / VA_output_auditing_cost_grid[auditing_baseline_idx]

## Output & consumption benefits from tripling auditing: 65% for output, 45% for consumption (reported in paper!)
# NOTE(review): 0.0261 and 0.0773 are hard-coded; presumably the baseline output and
# consumption costs of connections computed earlier in this file -- confirm they are
# updated whenever the baseline results change.
VA_output_auditing_cost_grid_relative[auditing_tripled_idx]/0.0261 # fraction of output costs from connections that auditing can solve 
consumption_auditing_cost_grid_relative[auditing_tripled_idx]/0.0773 # fraction of consumption costs from connections that auditing can solve 

#endregion 

#region ### 10. Extension 1: Wedges ###

# Wedge extension: load its functions & data
include("wedge_extension.jl")

# Parameters of the wedge economy
wedge_param_baseline = define_param_wedges()

# Profits/revenue for C & NC firms in the baseline (wedge) economy
results_profits_baseline_wedges = compute_profits_grid_wedges_baseline(;
    param = wedge_param_baseline,
    z_star_grid = z_star_grid_wedges,
    n_draws = 500,
)

# Labor share along the grid (expected labor over expected revenue) vs. the median share
Plots.plot(
    1:length(z_star_grid_wedges),
    results_profits_baseline_wedges.expected_labor ./
        results_profits_baseline_wedges.expected_revenue;
    label = "Variation in labor share",
)
Plots.hline!([beta_gross * (1 - vat_tax)]; label = "Median labor share")

# Capital share along the grid: (rental_rate - 1) * expected capital over expected revenue.
# This looks well-behaved!!
Plots.plot(
    1:length(z_star_grid_wedges),
    ((rental_rate - 1.0) .* results_profits_baseline_wedges.expected_capital ./
        results_profits_baseline_wedges.expected_revenue);
    label = "Variation in capital share",
)
Plots.hline!([alpha_gross * (1 - vat_tax)]; label = "Median capital share")

# Average subsidy rate (weighted by epsilon_proba_variation): 41.6%
sum(
    results_profits_baseline_wedges.subsidy_variation .*
    results_profits_baseline_wedges.epsilon_proba_variation,
) / sum(results_profits_baseline_wedges.epsilon_proba_variation)

# Average rent-seeking share (same weights): about 4.96%
sum(
    results_profits_baseline_wedges.rent_seeking_share_variation .*
    results_profits_baseline_wedges.epsilon_proba_variation,
) / sum(results_profits_baseline_wedges.epsilon_proba_variation)

# Aggregates for the case of wedges (w = 1, unit mass of firms, tax passed as `nothing`)
aggregates_baseline_wedges = compute_aggregates_wedges(;
    param = wedge_param_baseline,
    within_period_choices = results_profits_baseline_wedges,
    distrib = z_star_distrib_wedges,
    mass_firms = 1.0,
    w = 1.0,
    tax = nothing,
)

# About 0.23-0.4% of total output is spent on rent-seeking
aggregates_baseline_wedges.total_rent_seeking ./ aggregates_baseline_wedges.total_output
aggregates_baseline_wedges.total_rent_seeking / (
    aggregates_baseline_wedges.productive_intermediates +
    aggregates_baseline_wedges.total_rent_seeking
)

# Share of tax revenue spent on subsidizing connected firms: 39.5%
# (quite a bit lower than for baseline!)
aggregates_baseline_wedges.total_subsidies / aggregates_baseline_wedges.gross_tax_revenues

# Aggregate labor supply in the wedge case is taken from total labor demand
aggregates_baseline_wedges.total_labor_demand

# Initial assets in the wedge case are set to productive capital
initial_assets_wedges = aggregates_baseline_wedges.productive_capital

#endregion 

#region ### 10a) CF1 (No subsidies, no differential wedges) ###

#### Main counterfactual equilibrium without connections (but with wedges) ####

# Productivity grid with the role of Y taken out:
#   z = (z_star / Y^(1/sigma))^(sigma / (sigma - 1))
z_grid_wedges = let sigma = wedge_param_baseline.sigma
    (z_star_grid_wedges ./ (aggregates_baseline_wedges.total_output^(1 / sigma))) .^
        (sigma / (sigma - 1))
end
# Same transformation applied to the mean of log z_star
mean_log_z_wedges = let sigma = wedge_param_baseline.sigma
    (wedge_param_baseline.mean_log_z_star -
        (1 / sigma) * log(aggregates_baseline_wedges.total_output)) * (sigma / (sigma - 1))
end

# Solve for the counterfactual
cf_baseline_wedges = find_equilibrium_cf_wedges(;
    param = wedge_param_baseline,
    z_grid = z_grid_wedges,
    mean_log_z = mean_log_z_wedges,
    distrib = z_star_distrib_wedges,
    mass = 1.0,  # fix distribution
    aggr_labor_supply = aggregates_baseline_wedges.total_labor_demand,
    guess_w = 1.0,
    guess_Y = aggregates_baseline_wedges.total_output,
    n_draws = 500,
    update_param_w = 0.5,
    update_param_Y = 0.5,
    crit = 1e-6,
    max_iter = 500,
    verbose = false,
)

# Counterfactual costs

## Gross output is slightly lower: about 2.55%!
cf_baseline_wedges.Y / aggregates_baseline_wedges.total_output - 1

# Wages: -4.89%
cf_baseline_wedges.w - 1

# Value added output: 1.74% higher without connections
# NOTE(review): the counterfactual field is `value_added_output`, while the baseline (and
# the no-subsidies counterfactual further below) use `total_value_added_output` -- confirm
# against what find_equilibrium_cf_wedges returns.
cf_baseline_wedges.aggregates.value_added_output /
    aggregates_baseline_wedges.total_value_added_output - 1

# Main consumption effects: 8.54% higher consumption without connections
# (10.24% without capital)
(
    (rental_rate - delta - 1) * initial_assets_wedges +
    cf_baseline_wedges.aggregates.total_HH_income_noprofits
) / (
    (rental_rate - delta - 1) * initial_assets_wedges +
    aggregates_baseline_wedges.total_HH_income_noprofits
) - 1
cf_baseline_wedges.aggregates.total_HH_income_noprofits /
    aggregates_baseline_wedges.total_HH_income_noprofits - 1

# Including profits: baseline costs are around 4.0% - 4.5%
(
    (rental_rate - delta - 1) * initial_assets_wedges +
    cf_baseline_wedges.aggregates.total_HH_income
) / (
    (rental_rate - delta - 1) * initial_assets_wedges +
    aggregates_baseline_wedges.total_HH_income
) - 1
cf_baseline_wedges.aggregates.total_HH_income /
    aggregates_baseline_wedges.total_HH_income - 1

# Profits decline by 4.25%
cf_baseline_wedges.aggregates.total_net_profits /
    aggregates_baseline_wedges.total_net_profits - 1

# Government transfers increase by 57.48%
cf_baseline_wedges.aggregates.net_govt_transfers /
    aggregates_baseline_wedges.net_govt_transfers - 1


## Test: can the baseline equilibrium be recovered when starting from perturbed guesses?
test_baseline_wedges = find_equilibrium_baseline_wedges(;
    param = wedge_param_baseline,
    z_grid = z_grid_wedges,
    mean_log_z = mean_log_z_wedges,
    distrib = z_star_distrib_wedges,
    mass = 1.0,
    aggr_labor_supply = aggregates_baseline_wedges.total_labor_demand,
    tax = nothing,
    guess_w = 0.98,
    guess_Y = aggregates_baseline_wedges.total_output * 0.97,
    n_draws = 500,
    update_param_w = 0.5,
    update_param_Y = 0.5,
    crit = 1e-6,
    max_iter = 500,
)

# Yes, it is the same: wage and relative output should both be 1, and they are.
test_baseline_wedges.w
test_baseline_wedges.Y / aggregates_baseline_wedges.total_output

#endregion

#region ### 10b) CF2 (No subsidies only) ###

#### Counterfactual in which only subsidies are removed (differential wedges are kept) ####

# Implementation: raise the connection cost level c to a very high value (factor 1e9) and
# solve with the baseline wedge solver.
cf_no_subsidies_wedges = find_equilibrium_baseline_wedges(;
    param = define_param_wedges(
        wedge_param_baseline;
        connect_cost_level = wedge_param_baseline.connect_cost_level * 1e9,
    ),
    z_grid = z_grid_wedges,
    mean_log_z = mean_log_z_wedges,
    distrib = z_star_distrib_wedges,
    mass = 1.0,
    aggr_labor_supply = aggregates_baseline_wedges.total_labor_demand,
    tax = nothing,
    guess_w = 0.98,
    guess_Y = aggregates_baseline_wedges.total_output * 0.97,
    n_draws = 500,
    update_param_w = 0.5,
    update_param_Y = 0.5,
    crit = 1e-6,
    max_iter = 500,
)

## Gross output is slightly lower: about 2.86%!
cf_no_subsidies_wedges.Y / aggregates_baseline_wedges.total_output - 1

# Wage change (deviation from 1)
cf_no_subsidies_wedges.w - 1

# Value added output: 1.43% higher without connections
cf_no_subsidies_wedges.aggregates.total_value_added_output /
    aggregates_baseline_wedges.total_value_added_output - 1

# Main consumption effects: 8.20% higher consumption without connections
# (9.8% without capital)
(
    (rental_rate - delta - 1) * initial_assets_wedges +
    cf_no_subsidies_wedges.aggregates.total_HH_income_noprofits
) / (
    (rental_rate - delta - 1) * initial_assets_wedges +
    aggregates_baseline_wedges.total_HH_income_noprofits
) - 1
cf_no_subsidies_wedges.aggregates.total_HH_income_noprofits /
    aggregates_baseline_wedges.total_HH_income_noprofits - 1

# Including profits: baseline costs are around 3.67% - 4.1%
(
    (rental_rate - delta - 1) * initial_assets_wedges +
    cf_no_subsidies_wedges.aggregates.total_HH_income
) / (
    (rental_rate - delta - 1) * initial_assets_wedges +
    aggregates_baseline_wedges.total_HH_income
) - 1
cf_no_subsidies_wedges.aggregates.total_HH_income /
    aggregates_baseline_wedges.total_HH_income - 1

# Profits decline by 4.55%
cf_no_subsidies_wedges.aggregates.total_net_profits /
    aggregates_baseline_wedges.total_net_profits - 1

# Government transfers increase by 56.99%
cf_no_subsidies_wedges.aggregates.net_govt_transfers /
    aggregates_baseline_wedges.net_govt_transfers - 1

#endregion

#region ### 11. Extension 2: Production Network ###

# Production-network extension: load its functions & data
include("network_extension.jl")

# Baseline model with the IO network (firm solution + sectoral expenditure + aggregate
# labor supply + all other aggregates)
network_results_baseline = solve_industries_baseline(;
    objects = network_baseline_objects,
    io_elast = industry_io_elasticities,
    n_eps = 500,
)

### A couple of aggregates

# About 0.3% of total VA output is spent on rent-seeking (comparable to the baseline!)
network_results_baseline.aggregate_rent_seeking ./
    network_results_baseline.aggregate_final_use
network_results_baseline.aggregate_rent_seeking ./
    sum(network_results_baseline.sectoral_expenditure)
network_results_baseline.sectoral_rent_seeking ./
    network_results_baseline.aggregate_final_use
# about 55% of rent-seeking comes from Food!
network_results_baseline.sectoral_rent_seeking ./
    sum(network_results_baseline.sectoral_rent_seeking)

# Share of tax revenue spent on subsidizing connected firms: 30%!!
# That's very comparable to baseline!!
network_results_baseline.total_subsidies / network_results_baseline.gross_tax_revenues
# 5.5% of GDP spent on subsidies
network_results_baseline.total_subsidies ./ network_results_baseline.aggregate_final_use
# sectoral subsidies align relatively closely with the amount of rent-seeking
network_results_baseline.sectoral_subsidies / network_results_baseline.aggregate_final_use
network_results_baseline.sectoral_subsidies / network_results_baseline.total_subsidies

# Composition of HH income: profits (39.8%), net govt transfers (16.4%),
# labor income (43.8%)
(
    network_results_baseline.total_HH_income -
    network_results_baseline.total_HH_income_noprofits
) / network_results_baseline.total_HH_income
network_results_baseline.net_govt_transfers / network_results_baseline.total_HH_income
network_results_baseline.aggregate_labor_demand / network_results_baseline.total_HH_income

## Sectoral profit shares are not concentrated, but fairly well dispersed (so higher
## profits are more driven by lower eta_tilde in most sectors!)
network_results_baseline.sectoral_total_profits / network_results_baseline.total_profits

# Initial assets of the network economy are set to aggregate capital demand
initial_assets_network = network_results_baseline.aggregate_capital_demand

##### Main counterfactual: shutting down connections #####

# Inputs for the counterfactual: labor supply is fixed at baseline labor demand, and the
# sector-level parameters are collected into Float64 vectors (one entry per sector)
aggregate_labor_supply_network = network_results_baseline.aggregate_labor_demand
alphas_tilde_network =
    Float64[network_baseline_objects.param[s].alpha_tilde for s in 1:n_sectors]
betas_tilde_network =
    Float64[network_baseline_objects.param[s].beta_tilde for s in 1:n_sectors]
gammas_tilde_network =
    Float64[network_baseline_objects.param[s].gamma_tilde for s in 1:n_sectors]
sigmas_network = Float64[network_baseline_objects.param[s].sigma for s in 1:n_sectors]
industry_shares_network =
    Float64[network_baseline_objects.param[s].share_industry for s in 1:n_sectors]
etas_tilde_network = alphas_tilde_network + betas_tilde_network + gammas_tilde_network

### Move from the z_star baseline grid to the primitive z grid
# The sector vectors are transposed below so that they broadcast along the columns of the
# grid objects.

# Step 1: recover the correct z_star (not z_star_tilde) by taking out sectoral
# intermediate prices
network_z_star_grid_primitive =
    network_baseline_objects.z_star_grid .*
    (network_results_baseline.sectoral_interm_prices .^ gammas_tilde_network)'

# Step 2: recover z_grid by dividing out P * Y^(1/sigma) and raising to sigma/(sigma - 1)
network_z_grid_primitive = (
    network_z_star_grid_primitive ./
    (
        network_results_baseline.sectoral_prices .*
        (network_results_baseline.sectoral_output .^ (1 ./ sigmas_network))
    )'
) .^ ((sigmas_network ./ (sigmas_network .- 1))')

### Aggregated z by sector: sum over the first dimension of z^(sigma - 1), weighted by the
### z_star distribution and the industry shares
sectoral_z_aggr_network = (sum(
    network_z_grid_primitive .^ ((sigmas_network .- 1)') .*
    network_baseline_objects.z_star_distrib .* industry_shares_network',
    dims = 1,
))[1, :]

## Now define equilibrium conditions
# Residual function of the sectoral equilibrium without connections, in the form passed to
# nlsolve.
#
# Argument:
#   x -- vector of length 2*n_sectors: log sectoral prices followed by log sectoral
#        quantities (logs keep both objects positive during the root search).
# Returns:
#   vector of length 2*n_sectors: n_sectors supply-side residuals followed by n_sectors
#   demand-side residuals; all are zero at the equilibrium.
#
# Reads the globals n_sectors, network_nu, vat_tax, rental_rate,
# aggregate_labor_supply_network, industry_io_elasticities, the *_tilde_network parameter
# vectors, sigmas_network and sectoral_z_aggr_network.
function find_eq_industries_cf_noC(x)

    # extract objects (exp makes sure they are always positive; slicing already copies)
    sectoral_prices = exp.(x[1:n_sectors])
    sectoral_quantities = exp.(x[(n_sectors + 1):(2*n_sectors)])

    # construct implied aggregate price index (which may be different from 1)
    aggr_price = prod((network_nu ./ sectoral_prices).^(network_nu))

    # get wage
    wage = ((1-vat_tax)/aggregate_labor_supply_network)*sum( betas_tilde_network .* sectoral_prices .* sectoral_quantities )

    # get sectoral_input_prices
    # NOTE(review): `industry_io_elasticities ./ sectoral_prices` broadcasts the price
    # vector down the rows, i.e. row i is divided by sectoral_prices[i]. The demand block
    # below treats row i as the using sector (weights gammas[i], quantities[i]) and column j
    # as the supplier. If the input price index of sector i is meant to aggregate supplier
    # prices p_j, this needs `sectoral_prices'`. Left unchanged because it has to agree
    # with the baseline solver in network_extension.jl -- please confirm.
    sectoral_input_prices = (1 ./ prod( (industry_io_elasticities ./ sectoral_prices).^industry_io_elasticities, dims = 2))[:,1]

    # construct x_j_star
    x_j_star = (((1-vat_tax) .* alphas_tilde_network ./ (rental_rate .- 1.0)).^alphas_tilde_network).*(((1-vat_tax) .* betas_tilde_network ./ wage).^betas_tilde_network).*((gammas_tilde_network ./ sectoral_input_prices).^gammas_tilde_network)

    # intermediate-demand flows between sectors, built once and reused for both final
    # expenditure and the demand residuals; column sums give the intermediate demand
    # faced by each sector
    interm_flows = industry_io_elasticities .* gammas_tilde_network .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network
    interm_demand = sum(interm_flows, dims = 1)

    # get final Y: total sectoral revenue net of all intermediate demand, deflated by the
    # aggregate price index
    final_expenditure = sum(sectoral_prices .* sectoral_quantities) - sum(interm_demand)
    Y_F = final_expenditure / aggr_price

    # Verify guesses using sectoral supply and demand equations
    sectoral_supply_diff = ( (x_j_star.^sigmas_network).*sectoral_z_aggr_network ).^(1 ./ (1 .- sigmas_network)) .- sectoral_prices
    sectoral_demand_diff = network_nu .* Y_F .+ interm_demand[1,:] .- (sectoral_prices .* sectoral_quantities)

    return [sectoral_supply_diff; sectoral_demand_diff]
end 

# Solve the system, starting from the logs of the baseline sectoral prices and outputs.
# Original notes on solver methods: :trust_region and :newton were tried, :anderson
# doesn't work.
root_eq_industries_cf_noC = nlsolve(
    find_eq_industries_cf_noC,
    log.(vcat(
        network_results_baseline.sectoral_prices,
        network_results_baseline.sectoral_output,
    )),
)
root_eq_industries_cf_noC.zero

# Residuals at the solution: looks good!
find_eq_industries_cf_noC(root_eq_industries_cf_noC.zero)

## The solution satisfies the aggregate price normalization: it does!!
prod((network_nu ./ exp.(root_eq_industries_cf_noC.zero[1:n_sectors])) .^ network_nu)
# same index at baseline prices: this is exactly 1
prod((network_nu ./ network_results_baseline.sectoral_prices) .^ network_nu)

## Get aggregates from this equilibrium
# Compute aggregate outcomes of the no-connections network economy at a given point x.
#
# Arguments:
#   x        -- vector of length 2*n_sectors: log sectoral prices followed by log sectoral
#               quantities (same layout as the argument of the residual function).
#   io_elast -- input-output elasticity matrix (defaults to industry_io_elasticities).
# Returns a NamedTuple with Y_F, wage, aggregate_capital_demand, total_HH_income,
# total_HH_income_noprofits, total_HH_consumption_closed, aggregate_govt_revenue,
# aggregate_profits, sectoral_prices, sectoral_quantities and sectoral_revenue.
#
# Reads the globals n_sectors, network_nu, vat_tax, profit_tax, delta, rental_rate,
# aggregate_labor_supply_network, the *_tilde_network parameter vectors, sigmas_network
# and sectoral_z_aggr_network.
function eq_industries_cf_noC(x, io_elast = industry_io_elasticities)

    # extract objects (exp makes sure they are always positive; slicing already copies)
    sectoral_prices = exp.(x[1:n_sectors])
    sectoral_quantities = exp.(x[(n_sectors + 1):(2*n_sectors)])

    # construct implied aggregate price index (which may be different from 1)
    aggr_price = prod((network_nu ./ sectoral_prices).^(network_nu))

    # get wage
    wage = ((1-vat_tax)/aggregate_labor_supply_network)*sum( betas_tilde_network .* sectoral_prices .* sectoral_quantities )

    # get sectoral_input_prices
    # NOTE(review): `io_elast ./ sectoral_prices` broadcasts the price vector down the
    # rows, i.e. row i is divided by sectoral_prices[i]. The intermediate-flow term below
    # treats row i as the using sector and column j as the supplier. If the input price
    # index of sector i is meant to aggregate supplier prices p_j, this needs
    # `sectoral_prices'`. Left unchanged because it has to agree with the baseline solver
    # in network_extension.jl -- please confirm.
    sectoral_input_prices = (1 ./ prod( (io_elast ./ sectoral_prices).^io_elast, dims = 2))[:,1]

    # construct x_j_star
    x_j_star = (((1-vat_tax) .* alphas_tilde_network ./ (rental_rate .- 1.0)).^alphas_tilde_network).*(((1-vat_tax) .* betas_tilde_network ./ wage).^betas_tilde_network).*((gammas_tilde_network ./ sectoral_input_prices).^gammas_tilde_network)

    # common sectoral factor shared by every aggregate below (computed once)
    sectoral_scale = (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network

    # get final Y: total sectoral revenue net of all intermediate flows, deflated by the
    # aggregate price index
    interm_flows = io_elast .* gammas_tilde_network .* sectoral_scale
    final_expenditure = sum(sectoral_prices .* sectoral_quantities) - sum(interm_flows)
    Y_F = final_expenditure / aggr_price

    # aggregate capital
    aggregate_capital_demand = sum( alphas_tilde_network .* sectoral_scale )

    # aggregate VAT revenue
    aggregate_vat_revenue = vat_tax * sum( (1 .- gammas_tilde_network) .* sectoral_scale )

    # base shared by CIT revenue and after-tax profits
    profit_base = (1 - vat_tax) * sum( (1 .- etas_tilde_network) .* sectoral_scale )

    # aggregate CIT revenue
    aggregate_cit_revenue = profit_tax * profit_base

    # aggregate govt revenues
    aggregate_govt_revenue = aggregate_vat_revenue + aggregate_cit_revenue

    # aggregate profits (after CIT)
    aggregate_profits = (1 - profit_tax) * profit_base

    # HH consumption under closed economy
    # (enforces the aggregate resource constraint -- only true if closed economy or BOP = 0)
    total_HH_consumption_closed = Y_F - delta*aggregate_capital_demand

    # compute total HH income (not including capital)
    labor_income = wage*aggregate_labor_supply_network
    total_HH_income = aggregate_govt_revenue + aggregate_profits + labor_income
    total_HH_income_noprofits = aggregate_govt_revenue + labor_income

    ### return results ###
    return (
        Y_F = Y_F, wage = wage,
        aggregate_capital_demand = aggregate_capital_demand,
        total_HH_income = total_HH_income,
        total_HH_income_noprofits = total_HH_income_noprofits,
        total_HH_consumption_closed = total_HH_consumption_closed,
        aggregate_govt_revenue = aggregate_govt_revenue,
        aggregate_profits = aggregate_profits,
        sectoral_prices = sectoral_prices,
        sectoral_quantities = sectoral_quantities,
        sectoral_revenue = sectoral_prices .* sectoral_quantities
        )
end

# Aggregates at the counterfactual solution
results_eq_industries_cf_noC = eq_industries_cf_noC(root_eq_industries_cf_noC.zero)

## Aggregate final use (GDP) is about 3.2% smaller in this economy!
# The baseline denominator is final use net of rent-seeking.
results_eq_industries_cf_noC.Y_F / (
    network_results_baseline.aggregate_final_use -
    network_results_baseline.aggregate_rent_seeking
) - 1

## Wages are about 7.67% lower (deviation from 1)
results_eq_industries_cf_noC.wage - 1

# Main consumption effects: 1.5% higher (1.3% higher without assets)
results_eq_industries_cf_noC.total_HH_income_noprofits /
    network_results_baseline.total_HH_income_noprofits - 1
(
    (rental_rate - delta - 1) * initial_assets_network +
    results_eq_industries_cf_noC.total_HH_income_noprofits
) / (
    (rental_rate - delta - 1) * initial_assets_network +
    network_results_baseline.total_HH_income_noprofits
) - 1

# Including profits: gains from connections are negative (so profits are important here!!)
results_eq_industries_cf_noC.total_HH_income /
    network_results_baseline.total_HH_income - 1
(
    (rental_rate - delta - 1) * initial_assets_network +
    results_eq_industries_cf_noC.total_HH_income
) / (
    (rental_rate - delta - 1) * initial_assets_network +
    network_results_baseline.total_HH_income
) - 1

# Profit effects: -17.82% (these are really huge now!)
results_eq_industries_cf_noC.aggregate_profits / network_results_baseline.total_profits - 1

# Government transfers: +26.06
# (counterfactual government revenue vs. baseline net government transfers)
results_eq_industries_cf_noC.aggregate_govt_revenue /
    network_results_baseline.net_govt_transfers - 1


### Understand a bit more why output and consumption effects are smaller now.

# How do sectoral sizes shift in the CF? Scatter of log sectoral revenue (PY), baseline
# vs. CF. Each label is placed at its sector's point plus a small (dx, dy) offset; the
# sector order follows the index of the revenue vectors.
Plots.scatter(
    log.(network_results_baseline.sectoral_revenue),
    log.(results_eq_industries_cf_noC.sectoral_revenue);
    xlabel = "Sectoral log size (Baseline)",
    ylabel = "Sectoral log size (CF no connections)",
    legend = false,
    dpi = 200,
    annotations = [
        (
            log(network_results_baseline.sectoral_revenue[k]) + dx,
            log(results_eq_industries_cf_noC.sectoral_revenue[k]) + dy,
            Plots.text(sector_label, 8),
        ) for (k, (sector_label, dx, dy)) in enumerate((
            ("Chemicals", 0.05, 0.14),
            ("Machinery", 0.05, 0.18),
            ("Food", 0.0, 0.18),
            ("Minerals", 0.0, 0.18),
            ("Wood", 0.0, 0.18),
            ("Textiles", 0.0, 0.18),
        ))
    ],
)

# Plot range: smallest and largest log revenue across baseline and CF
min_val = min(
    minimum(log, network_results_baseline.sectoral_revenue),
    minimum(log, results_eq_industries_cf_noC.sectoral_revenue),
)
max_val = max(
    maximum(log, network_results_baseline.sectoral_revenue),
    maximum(log, results_eq_industries_cf_noC.sectoral_revenue),
)

# 45-degree line (y = x) over that range
Plots.plot!(
    [min_val, max_val],
    [min_val, max_val];
    label = "45° Line",
    linecolor = :black,
    linestyle = :dash,
    linewidth = 2,
)
savefig(joinpath(dout, "results_network_sectoral_size_distortions.png"))

## Find that compared to the undistorted economy: Food & Textiles are distorted downward, while remaining sectors are distorted upward 


#endregion

#region ### 11. Further Decompositions: Direct vs. Indirect effect! ###

## To isolate the effect of the production network, let's recompute the previous distorted economy without the production network.
## Then shut down connections in that economy & compare to the differential effects in the economy with the production network.

# Solve baseline model with no IO (getting firm solution + sectoral expenditure + aggr labor supply + all other aggregates)
# The IO elasticity matrix is replaced by the 6x6 identity matrix; all other inputs are the baseline network objects.
network_results_baseline_noIO = solve_industries_baseline(
	objects = network_baseline_objects, n_eps = 500, io_elast = Matrix{Float64}(I, 6, 6))

### check a couple of aggregates (bare expressions, meant for interactive inspection)

# still 0.3% of total VA output spent on rent-seeking (completely unaffected by IO network)
network_results_baseline_noIO.aggregate_rent_seeking ./ network_results_baseline_noIO.aggregate_final_use

# Also completely unaffected: 30.4% of tax revenue spent on subsidizing connected firms
network_results_baseline_noIO.total_subsidies / network_results_baseline_noIO.gross_tax_revenues

# Initial assets of the no-IO economy: set equal to baseline aggregate capital demand
initial_assets_network_noIO = network_results_baseline_noIO.aggregate_capital_demand

##### Now solve for main counterfactual: shutting down connections #####

# Need to prepare a couple of inputs
# Aggregate labor supply is held fixed at the baseline (no-IO) aggregate labor demand
aggregate_labor_supply_network_noIO = network_results_baseline_noIO.aggregate_labor_demand

### Need to move from z_star_baseline_grid to z_grid primitive
### (the transposes below broadcast one sector-level value across each column;
###  assumes the grids have one column per sector -- TODO confirm against network_baseline_objects)

# Step 1: Recover correct z_star (not z_star_tilde) by taking out sectoral intermediate prices (this is different now!!)
# Each column is multiplied by (sectoral intermediate price)^gamma_tilde of that sector.
network_noIO_z_star_grid_primitive = network_baseline_objects.z_star_grid .* (network_results_baseline_noIO.sectoral_interm_prices.^gammas_tilde_network)'

# Step 2: Recover z_grid
# Each column is divided by P_j * Y_j^(1/sigma_j) and then raised to the power sigma_j/(sigma_j - 1).
network_noIO_z_grid_primitive = ( network_noIO_z_star_grid_primitive ./ (network_results_baseline_noIO.sectoral_prices .* (network_results_baseline_noIO.sectoral_output.^(1 ./ sigmas_network)))').^((sigmas_network ./ (sigmas_network .- 1))')

### Get aggregated z by sector
# Column sums of z^(sigma_j - 1), weighted by z_star_distrib and the industry share; [1,:] turns the 1 x n result into a vector.
sectoral_z_aggr_network_noIO = (sum( network_noIO_z_grid_primitive.^((sigmas_network .- 1)') .* network_baseline_objects.z_star_distrib .* industry_shares_network', dims = 1))[1,:]


### Solve for CF economy in which there are no connections + no IO linkages (identity matrix for IO)

# Shut down IO network 
"""
    find_eq_industries_cf_noC_noIO(x)

Return the equilibrium residuals of the counterfactual economy without connections and
without IO linkages (identity IO matrix), evaluated at the guess `x`.

# Arguments
- `x`: stacked vector `[log(sectoral prices); log(sectoral quantities)]` of length
  `2*n_sectors`. The guess is passed in logs so that prices and quantities are always positive.

# Returns
A vector of length `2*n_sectors`: the sectoral supply residuals followed by the sectoral
demand residuals. All entries are zero at an equilibrium.

Reads the file-level globals `n_sectors`, `network_nu`, `vat_tax`, `rental_rate`,
`aggregate_labor_supply_network_noIO`, `alphas_tilde_network`, `betas_tilde_network`,
`gammas_tilde_network`, `sigmas_network` and `sectoral_z_aggr_network_noIO`.
"""
function find_eq_industries_cf_noC_noIO(x)

    # unpack the guess (exp keeps prices and quantities positive; slicing already copies)
    sectoral_prices = exp.(x[1:n_sectors])
    sectoral_quantities = exp.(x[(n_sectors + 1):(2*n_sectors)])

    # construct implied aggregate price index (which may be different from 1)
    # NOTE(review): this is prod((nu_j / p_j)^nu_j), the reciprocal of the usual Cobb-Douglas
    # index prod((p_j / nu_j)^nu_j); the two coincide only when the index equals 1 -- confirm
    # the orientation is intended.
    aggr_price = prod((network_nu ./ sectoral_prices).^(network_nu))

    # get wage
    wage = ((1-vat_tax)/aggregate_labor_supply_network_noIO)*sum( betas_tilde_network .* sectoral_prices .* sectoral_quantities )

    # IO elasticities: identity matrix, sized by n_sectors like every other sectoral object
    io_elast = Matrix{Float64}(I, n_sectors, n_sectors)

    # get sectoral_input_prices (with the identity matrix this reduces to the own sectoral price)
    sectoral_input_prices = (1 ./ prod( (io_elast ./ sectoral_prices).^io_elast, dims = 2))[:,1]

    # construct x_j_star
    x_j_star = (((1-vat_tax) .* alphas_tilde_network ./ (rental_rate .- 1.0)).^alphas_tilde_network).*(((1-vat_tax) .* betas_tilde_network ./ wage).^betas_tilde_network).*((gammas_tilde_network ./ sectoral_input_prices).^gammas_tilde_network)

    # intermediate-input expenditure matrix (row = buying sector, column = good bought);
    # computed once and reused for final expenditure and for the demand residual
    interm_expenditure = io_elast .* gammas_tilde_network .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network_noIO

    # get final Y: total revenue net of intermediate-input expenditure, deflated by aggr_price
    final_expenditure = sum(sectoral_prices .* sectoral_quantities) - sum( sum( interm_expenditure, dims = 1)' )
    Y_F = final_expenditure / aggr_price

    # Verify guesses using sectoral supply and demand equations
    sectoral_supply_diff = ( (x_j_star.^sigmas_network).*sectoral_z_aggr_network_noIO ).^(1 ./ (1 .- sigmas_network)) .- sectoral_prices
    sectoral_demand_diff = network_nu .* Y_F .+ sum( interm_expenditure, dims = 1)[1,:] .- (sectoral_prices .* sectoral_quantities)

    return [sectoral_supply_diff; sectoral_demand_diff]
end

# Find solution. Starting guess: baseline no-IO prices and output, stacked and taken in logs
# (the residual function expects its argument in logs).
# Solver notes kept from earlier experiments: method = :newton and method = :trust_region
# are alternatives; method = :anderson does not work here.
root_eq_industries_cf_noC_noIO = nlsolve(
	find_eq_industries_cf_noC_noIO,
	log.(vcat(network_results_baseline_noIO.sectoral_prices, network_results_baseline_noIO.sectoral_output)))
root_eq_industries_cf_noC_noIO.zero

# Sanity check: residuals at the reported root should all be close to zero
find_eq_industries_cf_noC_noIO(root_eq_industries_cf_noC_noIO.zero)

## Sanity check: the aggregate price expression at the solution should equal 1 (price normalization)
prod((network_nu ./ exp.(root_eq_industries_cf_noC_noIO.zero[1:n_sectors])).^(network_nu))

## Now see how this equilibrium differs 

## Get aggregates from this equilibrium
"""
    eq_industries_cf_noC_noIO(; x, io_elast = Matrix{Float64}(I, n_sectors, n_sectors), tax = vat_tax)

Compute aggregates of the counterfactual economy without connections and without IO
linkages at the point `x`.

# Keywords
- `x`: stacked vector `[log(sectoral prices); log(sectoral quantities)]` of length
  `2*n_sectors` (e.g. the root found for `find_eq_industries_cf_noC_noIO`).
- `io_elast`: IO elasticity matrix; defaults to the `n_sectors` x `n_sectors` identity matrix.
- `tax`: VAT rate used throughout; defaults to the global `vat_tax`.

# Returns
A named tuple with `Y_F`, `wage`, `aggregate_capital_demand`, `total_HH_income`,
`total_HH_income_noprofits`, `total_HH_consumption_closed`, `aggregate_govt_revenue`,
`aggregate_profits`, `sectoral_prices`, `sectoral_quantities` and `sectoral_revenue`.

Reads the file-level globals `n_sectors`, `network_nu`, `rental_rate`, `delta`,
`profit_tax`, `aggregate_labor_supply_network_noIO`, `alphas_tilde_network`,
`betas_tilde_network`, `gammas_tilde_network`, `etas_tilde_network`, `sigmas_network`
and `sectoral_z_aggr_network_noIO`.
"""
function eq_industries_cf_noC_noIO(; x, io_elast = Matrix{Float64}(I, n_sectors, n_sectors), tax = vat_tax)

    # unpack the point (exp keeps prices and quantities positive; slicing already copies)
    sectoral_prices = exp.(x[1:n_sectors])
    sectoral_quantities = exp.(x[(n_sectors + 1):(2*n_sectors)])

    # construct implied aggregate price index (which may be different from 1)
    aggr_price = prod((network_nu ./ sectoral_prices).^(network_nu))

    # get wage
    wage = ((1-tax)/aggregate_labor_supply_network_noIO)*sum( betas_tilde_network .* sectoral_prices .* sectoral_quantities )

    # get sectoral_input_prices
    # NOTE(review): `io_elast ./ sectoral_prices` divides row i by p_i, while the
    # intermediate-demand term below sums over rows (dims = 1). For the identity default this
    # makes no difference; confirm the orientation before passing a non-identity matrix.
    sectoral_input_prices = (1 ./ prod( (io_elast ./ sectoral_prices).^io_elast, dims = 2))[:,1]

    # construct x_j_star
    x_j_star = (((1-tax) .* alphas_tilde_network ./ (rental_rate .- 1.0)).^alphas_tilde_network).*(((1-tax) .* betas_tilde_network ./ wage).^betas_tilde_network).*((gammas_tilde_network ./ sectoral_input_prices).^gammas_tilde_network)

    # get final Y: total revenue net of intermediate-input expenditure, deflated by aggr_price
    final_expenditure = sum(sectoral_prices .* sectoral_quantities) - sum( sum( io_elast .* gammas_tilde_network .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network_noIO, dims = 1)' )
    Y_F = final_expenditure / aggr_price

    # aggregate capital
    aggregate_capital_demand = sum( alphas_tilde_network .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network_noIO)

    # aggregate VAT revenue
    aggregate_vat_revenue = tax * sum( (1 .- gammas_tilde_network) .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network_noIO )

    # aggregate CIT revenue: profit tax levied on the after-VAT profit share (1 - eta)
    aggregate_cit_revenue = profit_tax * (1 - tax) * sum( (1 .- etas_tilde_network) .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network_noIO )

    # aggregate govt revenues
    aggregate_govt_revenue = aggregate_vat_revenue + aggregate_cit_revenue

    # aggregate profits: the same after-VAT profit base as the CIT line, net of the profit tax.
    # (The factor is (1 - tax), matching the CIT base, and it honors the `tax` keyword.)
    aggregate_profits = (1 - profit_tax) * (1 - tax) * sum( (1 .- etas_tilde_network) .* (x_j_star.^sigmas_network) .* (sectoral_prices.^sigmas_network) .* sectoral_quantities .* sectoral_z_aggr_network_noIO )

    # HH consumption under closed economy
    # (enforce aggregate resource constraint -- only true if closed economy or BOP = 0)
    total_HH_consumption_closed = Y_F - delta*aggregate_capital_demand

    # compute total HH income (not including capital)
    total_HH_income = aggregate_govt_revenue + aggregate_profits + wage*aggregate_labor_supply_network_noIO
    total_HH_income_noprofits = aggregate_govt_revenue + wage*aggregate_labor_supply_network_noIO

    ### return results ###
    return (
        Y_F = Y_F, wage = wage,
        aggregate_capital_demand = aggregate_capital_demand,
        total_HH_income = total_HH_income,
        total_HH_income_noprofits = total_HH_income_noprofits,
        total_HH_consumption_closed = total_HH_consumption_closed,
        aggregate_govt_revenue = aggregate_govt_revenue,
        aggregate_profits = aggregate_profits,
        sectoral_prices = sectoral_prices,
        sectoral_quantities = sectoral_quantities,
        sectoral_revenue = sectoral_prices .* sectoral_quantities
        )
end

# Get equilibrium aggregates at the solved root (default identity IO matrix and default VAT rate)
# NOTE: the percentages quoted in the comments below were recorded by the author from an
# earlier run; re-check them after any change to the model code.
results_eq_industries_cf_noC_noIO = eq_industries_cf_noC_noIO(x = root_eq_industries_cf_noC_noIO.zero)

## Aggregate final use (GDP) is about 2.67% smaller in this economy!
## (CF final output relative to baseline final use net of rent-seeking)
results_eq_industries_cf_noC_noIO.Y_F / (network_results_baseline_noIO.aggregate_final_use - network_results_baseline_noIO.aggregate_rent_seeking) - 1

## How about wages? Wages are about 7.67% lower!
## (presumably the baseline wage is normalized to 1, so wage - 1 is the relative change -- TODO confirm)
results_eq_industries_cf_noC_noIO.wage - 1 

# Main consumption effects: 1.68% higher (1.4% higher without assets)
# First line: HH income excluding profits only. Second line: adds the net return on initial assets,
# (rental_rate - delta - 1) = 1/beta - 1, to numerator and denominator.
(results_eq_industries_cf_noC_noIO.total_HH_income_noprofits) / (network_results_baseline_noIO.total_HH_income_noprofits) - 1
((rental_rate - delta - 1)*initial_assets_network_noIO + results_eq_industries_cf_noC_noIO.total_HH_income_noprofits) / ((rental_rate - delta - 1)*initial_assets_network_noIO + network_results_baseline_noIO.total_HH_income_noprofits) - 1 

# If profits are included: Gains from connections are negative (so profits are important here!!)
# Same two comparisons as above, using HH income including profits.
(results_eq_industries_cf_noC_noIO.total_HH_income / network_results_baseline_noIO.total_HH_income) - 1 
((rental_rate - delta - 1)*initial_assets_network_noIO + results_eq_industries_cf_noC_noIO.total_HH_income) / ((rental_rate - delta - 1)*initial_assets_network_noIO + network_results_baseline_noIO.total_HH_income) - 1 

# Profit effects? -90.84% (these are really huge now!)
(results_eq_industries_cf_noC_noIO.aggregate_profits / network_results_baseline_noIO.total_profits) - 1 

# Govt transfers? +26.64% (CF government revenue relative to baseline net government transfers)
(results_eq_industries_cf_noC_noIO.aggregate_govt_revenue / network_results_baseline_noIO.net_govt_transfers) - 1 


### What are the indirect effects then? 

# Difference in differences! 

#endregion

#region ### Appendix: Costs with different values for sigma ###

## Create grid of sigmas: equally spaced values from the baseline sigma up to 25.
## The grid length is set here only; every container below is sized from length(sigma_grid).
sigma_grid = collect( range(baseline_param.sigma, stop = 25, length = 15) )

# Result container with one column per sigma:
#   row 1 = parameters, row 2 = CF steady-state equilibrium, row 3 = CF aggregates.
# It is pre-filled with copies of the baseline objects; every entry is overwritten in the loop below.
results_sigma_grid = repeat([deepcopy(baseline_param),deepcopy(baseline_cf_eq_SS_goods),deepcopy(baseline_cf_eq_SS_goods_aggregates)], 1, length(sigma_grid))

# Output and consumption effects (relative to baseline), one entry per sigma
results_sigma_grid_output = zeros(length(sigma_grid))
results_sigma_grid_consumption = zeros(length(sigma_grid))

# loop through equilibria (parallelized); iteration i only writes column/entry i of the containers.
# Assumes the solver helpers called below do not mutate shared state -- TODO confirm.
Threads.@threads for i in eachindex(sigma_grid)

	# show progress
	println("Solving entry ", i)

	# define param first (baseline parameters with sigma replaced)
	results_sigma_grid[1,i] = define_param(baseline_param, sigma = sigma_grid[i])
	param_i = results_sigma_grid[1,i]

	# find corresponding CF equilibrium (in SS)
	results_sigma_grid[2,i] = find_equilibrium_cf_goods(
		guess_w = 1.0, guess_Y = baseline_aggregates_goods.total_output,
		param = param_i,
		aggr_L = aggregate_labor_supply,
		update_param_w = 0.5, update_param_Y = 0.9, max_iter = 150, max_iter_w = 150, crit = 1e-12)
	eq_i = results_sigma_grid[2,i]

	# save output: CF value added relative to baseline output net of intermediates and rent-seeking
	results_sigma_grid_output[i] = (eq_i.output_VA / (baseline_aggregates_goods.total_output - baseline_aggregates_goods.productive_intermediates - baseline_aggregates_goods.total_rent_seeking)) - 1

	## Need z_grid

	# compute z grid based on baseline results & choice of sigma
	x_bar_sigma = (baseline_aggregates_goods.total_output)^(1/param_i.sigma)
	z_grid_sigma = (z_star_grid ./ x_bar_sigma).^(param_i.sigma/(param_i.sigma-1))

	# now compute aggregates to compute consumption
	results_sigma_grid[3,i] = compute_aggregates_noC_goods(
		w = eq_i.w,
		distrib = eq_i.distrib_results.SS_distrib,
		mass_firms = eq_i.mass,
		mass_entry = eq_i.distrib_results.mass_entry,
		exit_proba = eq_i.VF_results.exit_proba,
		within_period_choices = compute_profits_NC(w = eq_i.w, Y = eq_i.Y, param = param_i, z_grid = z_grid_sigma),
		VF_results = eq_i.VF_results,
		param = param_i)

	# get consumption effect: HH income (excl. profits) plus net return on initial assets, CF vs. baseline
	results_sigma_grid_consumption[i] = ((rental_rate - delta - 1)*initial_assets + results_sigma_grid[3,i].total_HH_income_noprofits) / ((rental_rate - delta - 1)*initial_assets + baseline_aggregates_goods.total_HH_income_noprofits) - 1

end

# express everything relative to the first grid point (the baseline sigma)
results_sigma_grid_output_relative = results_sigma_grid_output/results_sigma_grid_output[1] .- 1
results_sigma_grid_consumption_relative = results_sigma_grid_consumption/results_sigma_grid_consumption[1] .- 1

## Plot relative output and consumption effects against the elasticity of substitution
Plots.plot(
	sigma_grid, results_sigma_grid_output_relative,
	color = "red", lw = 2,
	label = "Output (VA)",
	xlabel = "Elasticity of substitution",
	ylabel = "% deviation from baseline", dpi = 200,
	thickness_scaling = 1.3,
	linestyle = :dash)
Plots.plot!(sigma_grid, results_sigma_grid_consumption_relative, color = "blue", lw = 2, label = "Consumption")
hline!([0.0], linestyle=:dash, color=:gray, label = "")
yticks_vals = 0.00:0.02:0.16  # tick positions (0% to 16% in steps of 2 percentage points)
yticks_labels = ["$(round(Int, 100 * y))%" for y in yticks_vals]
Plots.yticks!(yticks_vals, yticks_labels)
savefig(joinpath(dout,"results_aggregate_costs_different_sigma.png"))


#endregion 
